# Predictive Sampling with Logistic Regression

Hi,

I’m using the 0.3.4 release, and a recent logistic regression toy example shows that the `predictive` function does work as anticipated.

I noticed that the example makes use of `pyro.plate`, although I’m not too familiar with when this is necessary. If that is not the issue, I can’t work out why the toy example I wrote throws this error:

`ValueError: Shape mismatch inside plate('_num_predictive_samples') at site ...`

``````
import numpy as np
import torch
import pyro
import pyro.distributions as dist
import arviz as az

from pyro.infer.mcmc.api import MCMC
from pyro.infer.mcmc import NUTS
from pyro.infer.mcmc.util import predictive
from scipy.special import expit

np.random.seed(0)
ndims = 6 # Including intercept
ndata = 4000
X = np.random.randn(ndata, ndims - 1)
X = np.array([np.append(1, x) for x in X]) # Add an intercept
w_ = np.random.randn(ndims)  # hidden
noise_ = 0.1 * np.random.randn(ndata)  # hidden

y_lin_ = X.dot(w_) + noise_ # hidden
y_obs_binom = np.random.binomial(1, expit(y_lin_))

# To estimate w_ with w
def logistic_model(X):
    w = pyro.sample('w', dist.Normal(torch.zeros(ndims), torch.ones(ndims)))
    y_logit = torch.matmul(X, w)
    y = pyro.sample('y', dist.Bernoulli(logits=y_logit),
                    obs=torch.as_tensor(y_obs_binom, dtype=torch.float32))
    return y

nuts_kernel = NUTS(logistic_model, adapt_step_size=True, jit_compile=True, ignore_jit_warnings=True)
logistic_mcmc = MCMC(nuts_kernel, num_samples=1000, warmup_steps=500, num_chains=1)
logistic_mcmc.run(torch.as_tensor(X, dtype=torch.float32))

predictive(logistic_model, logistic_mcmc.get_samples(), torch.as_tensor(X, dtype=torch.float32), return_trace=True)
``````

One last question I’ll tack on: I couldn’t get `plot_posterior` from arviz to work with the new MCMC class at `pyro.infer.mcmc.api`, although it does run when importing MCMC from `pyro.infer.mcmc`. I’ve resorted to pulling out the trace myself and plotting:

``````
def plot_posterior(pyro_mcmc, var):
    posterior_dict = dict(
        enumerate(pyro_mcmc.get_samples()[var].t().numpy()))
    az.plot_posterior(posterior_dict)
``````

Thanks

Hi @jpryda, the `predictive` utility wraps the whole model in an outermost `plate` statement, which declares an additional independent dimension for each `sample` statement. Pyro requires that “all dimensions must be declared either dependent or conditionally independent”, though in HMC that restriction can sometimes be relaxed. Please check out the tensor shapes tutorial for more details about the `plate` statement.

@neerajprad do you have other suggestions on using `predictive` utility?

About arviz, I think your solution is pretty nice. Why do you need something other than that? I would modify it a bit (though I haven’t tested this yet):

``````
infer_data = az.convert_to_inference_data(
    {k: v.numpy()
     for k, v in logistic_mcmc.get_samples(group_by_chain=True).items()})
az.plot_posterior(infer_data)
``````

Thanks for the response! This was fixed by changing the model to:

``````
def logistic_model(X):
    with pyro.plate('plate_w', ndims):
        w = pyro.sample('w', dist.Normal(torch.zeros(ndims), torch.ones(ndims)))
    with pyro.plate('plate_y'):
        y = pyro.sample('y', dist.Bernoulli(logits=torch.matmul(w, X.t())),
                        obs=torch.as_tensor(y_obs_binom, dtype=torch.float32))
    return y
``````

Note that `predictive()` fails when I specify `logits=torch.matmul(X, w)` instead.
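One plausible reading of why the argument order matters (an assumption, not something confirmed in the thread): once `w` carries a leading sample dimension, `torch.matmul(X, w)` no longer has compatible inner dimensions, while `torch.matmul(w, X.t())` yields one row of logits per sample. The sizes below are small stand-ins, and the exact shape that `predictive` gives `w` may differ from the `(num_samples, ndims)` assumed here.

```python
import torch

torch.manual_seed(0)
num_samples, ndata, ndims = 5, 8, 6          # small stand-in sizes
X = torch.randn(ndata, ndims)
w_single = torch.randn(ndims)                # one draw of w
w_batched = torch.randn(num_samples, ndims)  # draws stacked on a leading dim

# For a single weight vector, both orderings give the same logits.
assert torch.allclose(torch.matmul(X, w_single),
                      torch.matmul(w_single, X.t()), atol=1e-5)

# With a leading sample dimension, w @ X.T still works.
logits = torch.matmul(w_batched, X.t())
print(logits.shape)  # torch.Size([5, 8])

# X @ w does not: (8, 6) @ (5, 6) has mismatched inner dimensions.
try:
    torch.matmul(X, w_batched)
    x_w_ok = True
except RuntimeError:
    x_w_ok = False
print(x_w_ok)  # False
```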

> @neerajprad do you have other suggestions on using `predictive` utility?

I think all the issues highlighted in this post stem from `plate`s not correctly accounting for the batch dimensions. I am beginning to think that it would be nice to have `predictive` take a `parallel=True` arg which, when set to `False`, does not wrap the model in an outermost `plate` and instead just runs it forward once per trace to generate predictions.

This will be more flexible, since with MCMC users won’t have to necessarily use `pyro.plate` in their models.
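The idea of running the model forward once per posterior draw can be sketched without any Pyro machinery. The helper below, `sequential_predictions`, is a hypothetical name and not part of Pyro. It only mimics the proposed `parallel=False` behavior for this particular logistic model, assuming `w_samples` has shape `(num_samples, ndims)`.

```python
import torch


def sequential_predictions(X, w_samples):
    """Draw one Bernoulli prediction vector per posterior draw of w."""
    draws = []
    for w in w_samples:                  # each w has shape (ndims,)
        logits = torch.matmul(X, w)      # plain, unbatched model code
        draws.append(torch.bernoulli(torch.sigmoid(logits)))
    return torch.stack(draws)            # shape (num_samples, ndata)


torch.manual_seed(0)
X = torch.randn(8, 6)
w_samples = torch.randn(5, 6)            # stand-in for MCMC samples of w
y_pred = sequential_predictions(X, w_samples)
print(y_pred.shape)  # torch.Size([5, 8])
```

The loop is slower than a single vectorized pass, but it places no shape requirements on the model body, so `torch.matmul(X, w)` works unchanged.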

> it will be nice to have `predictive` take in a `parallel=True`

`vmap` would truly win here. Sadly, there is no such functionality in PyTorch >"<