Hi,
I’m using the v0.3.4 release, and from a recent logistic regression toy example I see that the predictive function does indeed work there as anticipated.
I noticed that that example makes use of pyro.plate, although I’m not too familiar with when it is necessary. If that is not the issue, I can’t work out why the toy example I wrote throws this error:
ValueError: Shape mismatch inside plate('_num_predictive_samples') at site ...
import numpy as np
import torch
import pyro
import pyro.distributions as dist
import arviz as az
from pyro.infer.mcmc.api import MCMC
from pyro.infer.mcmc import NUTS
from pyro.infer.mcmc.util import predictive
from scipy.special import expit
np.random.seed(0)
ndims = 6 # Including intercept
ndata = 4000
X = np.random.randn(ndata, ndims - 1)
X = np.array([np.append(1, x) for x in X]) # Add an intercept
w_ = np.random.randn(ndims) # hidden
noise_ = 0.1 * np.random.randn(ndata) # hidden
y_lin_ = X.dot(w_) + noise_ # hidden
y_obs_binom = np.random.binomial(1, expit(y_lin_))
# To estimate w_ with w
def logistic_model(X):
w = pyro.sample('w', dist.Normal(torch.zeros(ndims), torch.ones(ndims)))
y_logit = torch.matmul(X, w)
y = pyro.sample('y', dist.Bernoulli(logits=y_logit),
obs=torch.as_tensor(y_obs_binom, dtype=torch.float32))
return y
nuts_kernel = NUTS(logistic_model, adapt_step_size=True, jit_compile=True, ignore_jit_warnings=True)
logistic_mcmc = MCMC(nuts_kernel, num_samples=1000, warmup_steps=500, num_chains=1)
logistic_mcmc.run(torch.as_tensor(X, dtype=torch.float32))
predictive(logistic_model, logistic_mcmc.get_samples(), torch.as_tensor(X, dtype=torch.float32), return_trace=True)
One last question I’ll tack on: I couldn’t get arviz’s plot_posterior to work with the new MCMC class in pyro.infer.mcmc.api, though it does run when importing MCMC from pyro.infer.mcmc. I’ve resorted to pulling the samples out myself and plotting:
def plot_posterior(pyro_mcmc, var):
posterior_dict = dict(
enumerate(pyro_mcmc.get_samples()[var].t().numpy()))
az.plot_posterior(posterior_dict)
Thanks