Predictive Sampling with Logistic Regression

Hi,

I’m using the 0.3.4 release, and a recent logistic regression toy example shows that the predictive function does indeed work as anticipated.

I noticed that the example does make use of pyro.plate, although I’m not too familiar with when that is necessary. If that is not the issue, I can’t work out why the toy example I wrote throws this error:

ValueError: Shape mismatch inside plate('_num_predictive_samples') at site ...

import numpy as np
import torch
import pyro
import pyro.distributions as dist
import arviz as az

from pyro.infer.mcmc.api import MCMC
from pyro.infer.mcmc import NUTS
from pyro.infer.mcmc.util import predictive
from scipy.special import expit

np.random.seed(0)
ndims = 6 # Including intercept
ndata = 4000
X = np.random.randn(ndata, ndims - 1)
X = np.array([np.append(1, x) for x in X]) # Add an intercept
w_ = np.random.randn(ndims)  # hidden
noise_ = 0.1 * np.random.randn(ndata)  # hidden

y_lin_ = X.dot(w_) + noise_ # hidden
y_obs_binom = np.random.binomial(1, expit(y_lin_))

# To estimate w_ with w
def logistic_model(X):
    w = pyro.sample('w', dist.Normal(torch.zeros(ndims), torch.ones(ndims)))
    y_logit = torch.matmul(X, w)
    y = pyro.sample('y', dist.Bernoulli(logits=y_logit),
                    obs=torch.as_tensor(y_obs_binom, dtype=torch.float32))
    return y

nuts_kernel = NUTS(logistic_model, adapt_step_size=True, jit_compile=True, ignore_jit_warnings=True)
logistic_mcmc = MCMC(nuts_kernel, num_samples=1000, warmup_steps=500, num_chains=1)
logistic_mcmc.run(torch.as_tensor(X, dtype=torch.float32))

predictive(logistic_model, logistic_mcmc.get_samples(), torch.as_tensor(X, dtype=torch.float32), return_trace=True)

One last question I’ll tack on: I couldn’t get plot_posterior from arviz to work with the new MCMC class in pyro.infer.mcmc.api, although it does run when importing MCMC from pyro.infer.mcmc. I’ve resorted to pulling out the trace myself and plotting it:

def plot_posterior(pyro_mcmc, var):
    posterior_dict = dict(
        enumerate(pyro_mcmc.get_samples()[var].t().numpy()))
    az.plot_posterior(posterior_dict)

Thanks

Hi @jpryda, the predictive utility wraps the whole model in an outermost plate statement, which declares an additional independent dimension for every sample statement. In Pyro, “All dimensions must be declared either dependent or conditionally independent,” though with HMC that restriction can sometimes be relaxed. In your model, the six entries of w sit in the batch shape of the Normal prior without being declared either way, so they clash with that extra plate. Please check out the tensor shapes tutorial for more details about the plate statement.

@neerajprad do you have other suggestions on using the predictive utility?

About arviz, I think your solution is pretty nice. Why do you need something other than that? 🙂 I would modify it a bit (though I haven’t tested it yet):

infer_data = az.convert_to_inference_data(
    {k: v.numpy() for k, v in logistic_mcmc.get_samples(group_by_chain=True).items()})
az.plot_posterior(infer_data)

Thanks for the response! This was fixed by changing the model to:

def logistic_model(X):
    with pyro.plate('plate_w', ndims):
        w = pyro.sample('w', dist.Normal(torch.zeros(ndims), torch.ones(ndims)))
    with pyro.plate('plate_y'):
        y = pyro.sample('y', dist.Bernoulli(logits=torch.matmul(w, X.t())),
                        obs=torch.as_tensor(y_obs_binom, dtype=torch.float32))
    return y

Note that predictive() still fails if I specify logits=torch.matmul(X, w) instead.
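Why the argument order matters can be checked with plain tensors. Under predictive, `w` arrives with a leading sample dimension, so the sketch below compares both orderings for a single and a batched `w` (random data, not the posterior; sizes shrunk from the thread's 4000 data points and 1000 draws to keep it light). `batch_logits` is a hypothetical helper, not a Pyro API.

```python
import torch

ndata, ndims, num_samples = 400, 6, 100
X = torch.randn(ndata, ndims)
w_single = torch.randn(ndims)                # one weight vector, as during MCMC
w_batched = torch.randn(num_samples, ndims)  # stacked draws, as under predictive's plate

# matmul(X, w) only works for a single w: with a 2-D w it tries to contract
# X's last dim (ndims) with w's second-to-last dim (num_samples).
matmul_x_w_failed = False
try:
    torch.matmul(X, w_batched)
except RuntimeError:
    matmul_x_w_failed = True

# matmul(w, X.t()) contracts over ndims and keeps any leading sample dims,
# so the same line works with and without the extra batch dimension.
single_logits = torch.matmul(w_single, X.t())    # shape (ndata,)
batched_logits = torch.matmul(w_batched, X.t())  # shape (num_samples, ndata)


def batch_logits(X, w):
    # Equivalent that keeps X first: treat w as a column vector, then drop it.
    return torch.matmul(X, w.unsqueeze(-1)).squeeze(-1)


print(batch_logits(X, w_batched).shape)  # torch.Size([100, 400])
```

In other words, the fixed model's `torch.matmul(w, X.t())` is batch-agnostic, whereas `torch.matmul(X, w)` silently assumes `w` is 1-D.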

> @neerajprad do you have other suggestions on using the predictive utility?

I think all the issues highlighted in this post stem from plates not correctly accounting for the batch dimensions. I am beginning to think that it will be nice to have predictive take in a parallel=True arg, which, when set to False, does not wrap the model in an outermost plate and instead runs the model forward once per posterior sample to generate predictions.

This would be more flexible, since MCMC users would not necessarily have to use pyro.plate in their models.

> it will be nice to have predictive take in a parallel=True

vmap would truly win here 😃 Sadly, there is no such functionality in PyTorch >"<
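For readers on newer PyTorch: releases from 2.0 onward include `torch.func.vmap`. The sketch below only vectorizes a plain tensor function over posterior draws (random stand-ins here); it does not vectorize `pyro.sample` calls, so treat it as an illustration of the idea rather than a drop-in replacement for predictive.

```python
import torch
from torch.func import vmap

ndata, ndims, num_samples = 400, 6, 100
X = torch.randn(ndata, ndims)
w_draws = torch.randn(num_samples, ndims)  # random stand-ins for posterior draws


def logits_for_one_draw(w):
    # Written for a single weight vector, as in the question's model.
    return torch.matmul(X, w)


# vmap maps the function over dim 0 of w_draws, so the per-draw code
# never sees the extra sample dimension.
all_logits = vmap(logits_for_one_draw)(w_draws)
print(all_logits.shape)  # torch.Size([100, 400])
```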