Using arviz for pyro

I want to use a few of the plotting abilities of ArviZ. I managed to create a few things from Pyro, but I am struggling to get the likelihood in correctly. My model is defined as follows:

```python
def model_pyro_logit(x_c, y):
    C = x_c.shape[1]
    alpha = pyro.sample('alpha', dist.Normal(0.0, 1.0))
    beta_c = pyro.sample('lambda', dist.Normal(torch.zeros(C), torch.ones(C)))
    y_loc = alpha + (x_c * beta_c).sum(dim=1)
    with pyro.plate('data'):
        pyro.sample('y', dist.Bernoulli(logits=y_loc), obs=y)

# create the model
nuts_kernel = NUTS(model_pyro_logit)
mcmc = MCMC(kernel=nuts_kernel, num_samples=500, warmup_steps=600,
            num_chains=1)
mcmc.run(x_c, y)
e_samples = {k: v.detach().cpu().numpy() for k, v in mcmc.get_samples().items()}

# create the ArviZ data
posterior_samples = mcmc.get_samples()
posterior_predictive = Predictive(model_pyro_logit, posterior_samples).get_samples(x_c, y=None)
prior = Predictive(FED_models.model_pyro_logit, num_samples=600).get_samples(x_c, y=None)

pyro_data = az.from_pyro(
    mcmc,
    prior=prior,
    posterior_predictive=posterior_predictive,
    coords={"resp": np.arange(x_c.shape[0])},
    dims={"data": ["resp"]},
)
```

However, I get an error message saying the likelihood cannot be defined like this (which probably has something to do with the way I define `coords` and `dims`). I tried a few things, but nothing seems to work. Any ideas?

@Helena.H Could you try `dist.Normal(torch.zeros(C), torch.ones(C)).to_event(1)`? Whenever you get stuck with shapes, the tensor shape tutorial will help. :wink:
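As a rough illustration of what `.to_event(1)` changes (a NumPy sketch of the idea, not Pyro's implementation): without it, the `C` independent Normals form a batch and `log_prob` returns one value per coefficient; with it, the last dimension is treated as a single event and the log-probabilities are summed into one value for the whole vector. The coefficient values below are made up for illustration.

```python
import numpy as np

def normal_logpdf(x, loc, scale):
    # elementwise log-density of independent Normal(loc, scale) variables
    return -0.5 * ((x - loc) / scale) ** 2 - np.log(scale) - 0.5 * np.log(2 * np.pi)

C = 3
beta = np.array([0.5, -1.0, 2.0])  # hypothetical coefficient values
loc, scale = np.zeros(C), np.ones(C)

# Without to_event: batch_shape (C,), one log-prob entry per coefficient
batch_log_prob = normal_logpdf(beta, loc, scale)
# With to_event(1): event_shape (C,), log-probs summed over the event dimension
event_log_prob = batch_log_prob.sum(axis=-1)

print(batch_log_prob.shape, event_log_prob.shape)  # (3,) ()
```

Because the coefficients then count as one sample site value rather than a batch, their shape no longer has to line up with plates such as `data`.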

I still get the following error message:

```
UserWarning: Could not get vectorized trace, log_likelihood group will be omitted. Check your model vectorization or set log_likelihood=False
  "Could not get vectorized trace, log_likelihood group will be omitted. "
```

I am not sure what causes the issue. Without more information, my guess is that you can replace `(x_c * beta_c).sum(dim=1)` with `(x_c * beta_c).sum(dim=-1)`. Could you provide the full code or post the full error message here?
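The reasoning behind preferring `dim=-1`, sketched with NumPy: if the posterior draws end up in their own leading dimension (the shape `(S, 1, C)` below is an assumption for illustration, not necessarily how every tool lays them out), positive axis indices shift while `-1` still points at the covariate axis. `S`, `N` and `C` are hypothetical sizes.

```python
import numpy as np

S, N, C = 4, 10, 3  # hypothetical: draws, observations, covariates
x_c = np.random.randn(N, C)

# Single draw: beta has shape (C,), the product has shape (N, C)
beta = np.random.randn(C)
prod = x_c * beta
print(prod.sum(axis=1).shape, prod.sum(axis=-1).shape)  # (10,) (10,)

# Draws in their own broadcast dimension: beta has shape (S, 1, C)
beta_vec = np.random.randn(S, 1, C)
prod_vec = x_c * beta_vec  # shape (S, N, C)
print(prod_vec.sum(axis=1).shape)   # (4, 3): sums over observations, not intended
print(prod_vec.sum(axis=-1).shape)  # (4, 10): sums over covariates, as intended
```

This only fixes the axis choice; it does not help if the draws are laid out so that `x_c * beta_c` cannot broadcast at all.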

I applied your suggestion and the error remains. The full error message can be seen below:

You can get the traceback by importing `warnings` and calling `warnings.simplefilter("error")`, which turns the warning into an exception. If you run the code below, you will get the error at `(x_c * beta_c).sum(dim=-1)`. I fixed it with `x_c.matmul(beta_c.unsqueeze(-1)).squeeze(-1)` :smiley:
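A self-contained sketch of why this trick exposes the real problem. It assumes, as the warning text suggests, that the library catches an internal exception and only reports it as a warning; `build_log_likelihood` is a hypothetical stand-in for that library code. With the `"error"` filter the warning is raised instead, and its `__context__` is the original exception, which is what the traceback prints first.

```python
import warnings

def build_log_likelihood():
    # Hypothetical stand-in for library code that catches an internal error
    # and reports it only as a warning.
    try:
        raise ValueError("shapes cannot be broadcast")  # the real cause
    except ValueError:
        warnings.warn("Could not get vectorized trace", UserWarning)
        return None

# Warning is recorded and execution continues; the real cause stays hidden.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    result = build_log_likelihood()
print(result, len(caught))  # None 1

# With "error", the warning is raised as an exception whose __context__
# is the original error.
with warnings.catch_warnings():
    warnings.simplefilter("error")
    try:
        build_log_likelihood()
    except UserWarning as exc:
        cause = exc.__context__
print(type(cause).__name__, cause)  # ValueError shapes cannot be broadcast
```

`catch_warnings()` restores the previous filters on exit, so the `"error"` setting does not leak into the rest of a session.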

```python
import warnings

import arviz as az
import numpy as np
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import NUTS, MCMC, Predictive

def model_pyro_logit(x_c, y):
    C = x_c.shape[1]
    alpha = pyro.sample('alpha', dist.Normal(0.0, 1.0))
    # to_event(1): the C coefficients form one vector-valued sample site
    beta_c = pyro.sample('lambda', dist.Normal(torch.zeros(C), torch.ones(C)).to_event(1))
    # matmul treats any extra leading (sample) dimensions of beta_c as batch dimensions
    y_loc = alpha + x_c.matmul(beta_c.unsqueeze(-1)).squeeze(-1)
    with pyro.plate('data', x_c.shape[0]):
        pyro.sample('y', dist.Bernoulli(logits=y_loc), obs=y)

# create the model and run MCMC on toy data
nuts_kernel = NUTS(model_pyro_logit)
mcmc = MCMC(kernel=nuts_kernel, num_samples=500, warmup_steps=600,
            num_chains=1)
x_c = torch.randn(10, 3)
y = torch.ones(10)
mcmc.run(x_c, y)
e_samples = {k: v.detach().cpu().numpy() for k, v in mcmc.get_samples().items()}

# create the ArviZ data
posterior_samples = mcmc.get_samples()
posterior_predictive = Predictive(model_pyro_logit, posterior_samples)(x_c, y=None)
prior = Predictive(model_pyro_logit, num_samples=600)(x_c, y=None)

# debugging aid: turn warnings into exceptions so the full traceback is shown
warnings.simplefilter("error")
pyro_data = az.from_pyro(
    mcmc,
    prior=prior,
    posterior_predictive=posterior_predictive,
    coords={"resp": np.arange(x_c.shape[0])},
    # dims is keyed by variable (sample site) names such as "y", not plate names
    dims={"y": ["resp"]},
)
```
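Why the `matmul` form helps, sketched with NumPy shapes. This assumes (it is not stated in the thread) that in the vectorized trace `beta_c` has shape `(S, C)` and `alpha` has shape `(S, 1)`, while `x_c` stays `(N, C)`. Elementwise multiplication then has to broadcast `N` against `S` and fails unless they happen to be equal, whereas `matmul` treats the leading `S` as a batch dimension. `S`, `N` and `C` are hypothetical sizes; `torch.matmul` broadcasts batch dimensions the same way.

```python
import numpy as np

S, N, C = 5, 10, 3  # hypothetical: posterior draws, observations, covariates
x_c = np.random.randn(N, C)
alpha = np.random.randn(S, 1)
beta_c = np.random.randn(S, C)

# Elementwise version: (N, C) * (S, C) cannot broadcast when N != S
broadcast_failed = False
try:
    (x_c * beta_c).sum(axis=-1)
except ValueError as err:
    broadcast_failed = True
    print("broadcast error:", err)

# matmul version: (N, C) @ (S, C, 1) -> (S, N, 1), then drop the trailing axis
y_loc = alpha + np.matmul(x_c, beta_c[..., None]).squeeze(-1)
print(y_loc.shape)  # (5, 10)

# The same expression still works for a single draw, as during MCMC
y_loc_single = 0.3 + np.matmul(x_c, beta_c[0][..., None]).squeeze(-1)
print(y_loc_single.shape)  # (10,)
```

The design point is that the contraction over covariates is expressed as a matrix product, so whatever leading dimensions the draws bring along are carried through instead of colliding with the observation axis.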

This solution worked perfectly! Thanks for your help :slight_smile: