I am trying to use the `arviz.from_pyro` function, but I get the following warning:
“arviz/data/io_pyro.py:158: UserWarning: Could not get vectorized trace, log_likelihood group will be omitted. Check your model vectorization or set log_likelihood=False”
Here is a minimal reproducible example:
```python
import pyro
import torch
import pyro.distributions as dist
import pyro.infer.mcmc as mcmc
import arviz as az

x = torch.tensor([1.0, 2.0, 3.0, 4, 5])
category = torch.tensor([0, 1, 1, 2, 2])  # Species ID for each data point
y = torch.tensor([2, 6, 8, 11, 13])
n_categories = torch.unique(category).shape[0]

# Define your model with appropriate vectorization
def model(x, categories, y):
    # Slope shared by all categories
    beta = pyro.sample('beta', dist.Normal(2, 1))
    # One intercept per category
    with pyro.plate('category', n_categories):
        intercepts = pyro.sample('intercept', dist.Uniform(0, 4))
    with pyro.plate('data', len(x)):
        # Look up the intercept of each data point's category
        mean = beta*x + intercepts[categories]
        pyro.sample('obs', dist.Normal(mean, 1), obs=y)

nuts_kernel = mcmc.NUTS(model)
mcmc_run = mcmc.MCMC(nuts_kernel, num_samples=300, warmup_steps=200)
mcmc_run.run(x, category, y)
az.from_pyro(mcmc_run)
```
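Here is a small PyTorch-only sketch of what I suspect happens. It assumes (from my reading of Pyro's `Predictive`, not from anything the warning says) that the vectorized trace gives `beta` the shape `(num_samples, 1)` and `intercept` the shape `(num_samples, n_categories)`. The tensors are random stand-ins, and `num_samples = 300` is just a hypothetical size.

```python
import torch

# Hypothetical sizes: 300 posterior draws, 3 categories, 5 data points
num_samples = 300
category = torch.tensor([0, 1, 1, 2, 2])
x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])

# Unbatched shapes, as during a single NUTS step: indexing works
intercepts = torch.rand(3)
print(intercepts[category].shape)  # torch.Size([5])

# Assumed batched shapes in the vectorized trace (leading sample dimension)
beta_b = torch.randn(num_samples, 1)
intercepts_b = torch.rand(num_samples, 3)

# Plain indexing selects along dim 0, i.e. the SAMPLE axis, not the category axis
wrong = intercepts_b[category]
print(wrong.shape)  # torch.Size([5, 3])
try:
    _ = beta_b * x + wrong  # (300, 5) + (5, 3) cannot broadcast
    broadcast_failed = False
except RuntimeError as err:
    broadcast_failed = True
    print("broadcast error:", err)

# Indexing the LAST dimension keeps any leading batch dimensions intact
right = intercepts_b[..., category]
mean = beta_b * x + right
print(mean.shape)  # torch.Size([300, 5])
```

If this is right, the model only "works" during MCMC because `intercepts` has no extra leading dimension there; once a sample dimension is prepended, `intercepts[categories]` picks rows of draws instead of categories.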
I know the problem is the indexing in `intercepts[categories]`, but I can't get past it. Why does this indexing interfere with vectorization, and how can I fix it so that ArviZ still gets the log-likelihood information?
Thanks!