Shape mismatch when predicting on new data points

I am new to Pyro, and I am running into what I assume is a simple error when trying to predict on data not seen by the model during training (i.e. sampling from the posterior predictive distribution for new data points). After reading related forum threads [1-4] and the Pyro codebase [5], I am still stumped. A toy example:

import torch
from torch.distributions import constraints
import numpy as np

import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO, Predictive

def model(predictors, pheno):
    n_predictors = predictors.shape[1]
    gamma_0 = pyro.sample("g0", dist.Normal(0., 1.))
    covs = pyro.sample("covs", dist.Normal(torch.zeros(n_predictors), torch.ones(n_predictors)).to_event(1))
    l = (gamma_0 + torch.matmul(predictors, covs.squeeze())).squeeze(-1)
    with pyro.plate("pheno", predictors.shape[0]):
        p = pyro.sample("obs", dist.Bernoulli(logits=l), obs=pheno)
    return p

def guide(predictors, pheno):
    n_predictors = predictors.shape[1]
    g0_loc = pyro.param("g0_loc", torch.tensor(0.))
    g0_scale = pyro.param("g0_scale", torch.tensor(1.), constraint=constraints.positive)
    gamma_0 = pyro.sample("g0", dist.Normal(g0_loc, g0_scale))
    covs_loc = pyro.param("covs_loc", torch.zeros(n_predictors))
    covs = pyro.sample("covs", dist.Normal(covs_loc, 1.).to_event(1))

predictors = torch.rand([100, 3])
coefs = torch.tensor([1., -5, 0.2])
pheno = torch.bernoulli(torch.sigmoid(0.2 + torch.matmul(predictors, coefs)))

# optim = pyro.optim.Adam({"lr": 0.01})
num_steps = 5000
lr0 = 0.01
gamma = 0.1
lrd = gamma ** (1 / num_steps)  # per-step decay, so lr ends at lr0 * gamma after num_steps
optim = pyro.optim.ClippedAdam({'lr': lr0, 'lrd': lrd})
svi = SVI(model, guide, optim, loss=Trace_ELBO())

for i in range(num_steps):
    loss = svi.step(predictors, pheno)
    if i % 100 == 0:
        print(f"step {i}: loss = {loss:.2f}")
predictive = Predictive(model, guide=guide, num_samples=100)
samples = predictive(predictors, pheno)
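For reference, here is the model and guide above written out in my own notation: x_i is row i of predictors, y_i is the matching entry of pheno, \sigma is the logistic sigmoid, and \beta stands for covs. The data are simulated from the same form with \gamma_0 = 0.2 and \beta = (1, -5, 0.2). The last line is the posterior predictive probability that predicting on new data is meant to approximate.

```latex
\begin{aligned}
\gamma_0 &\sim \mathcal{N}(0, 1), \qquad \beta \sim \mathcal{N}(\mathbf{0}, I_3), \\
y_i \mid \gamma_0, \beta &\sim \operatorname{Bernoulli}\big(\sigma(\gamma_0 + x_i^\top \beta)\big), \qquad i = 1, \dots, 100, \\
q(\gamma_0) &= \mathcal{N}(\mu_{g0}, s_{g0}^2), \qquad q(\beta) = \mathcal{N}(\mu_{\beta}, I_3), \\
p(y_* = 1 \mid y) &= \int \sigma(\gamma_0 + x_*^\top \beta)\, p(\gamma_0, \beta \mid y)\, d\gamma_0\, d\beta .
\end{aligned}
```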

This all works, and then I go to generate new data and predict on it.

oob_predictors = torch.rand([5, 3])
oob_data = torch.bernoulli(torch.sigmoid(0.2 + torch.matmul(oob_predictors, coefs)))

oob_predictive = Predictive(model, posterior_samples=samples, return_sites=['obs'])
oob_preds = oob_predictive(oob_predictors, None)

## Eliding most of the traceback; for completeness, it points to _predictive_sequential() inside the Predictive module:

    site: torch.stack([s[site] for s in collected]).reshape(shape)
RuntimeError: shape '[100, 5]' is invalid for input of size 10000

I do not know how to troubleshoot this: [100, 5] is the shape I want the function to return. I don't know where the input size of 10000 is coming from, though, as the samples object appears to contain only 400 values in total:

> samples['covs'].shape
torch.Size([100, 1, 3])
> samples['g0'].shape
torch.Size([100, 1])
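The 10000 can be reproduced without Pyro. Below is a simplified numpy sketch, not Pyro's actual code, of what the traceback line appears to do: collect one value per posterior draw for the returned site, stack them, and reshape to (num_samples,) + site shape. The assumption is that, because 'obs' is among the posterior samples, each of the 100 runs of the model is pinned to a stored training-sized draw of 100 values instead of sampling 5 new ones. The helper name stack_and_reshape is hypothetical.

```python
import numpy as np

NUM_SAMPLES = 100  # num_samples passed to Predictive
N_TRAIN = 100      # rows in the training predictors
N_NEW = 5          # rows in oob_predictors


def stack_and_reshape(collected, shape):
    """Hypothetical stand-in for the stack-then-reshape step in the traceback."""
    return np.stack(collected).reshape(shape)


rng = np.random.default_rng(0)

# Assumed failure mode: every run reuses a stored training-sized 'obs' draw,
# so each collected value has N_TRAIN entries instead of N_NEW.
pinned = [rng.integers(0, 2, size=N_TRAIN) for _ in range(NUM_SAMPLES)]
print(np.stack(pinned).size)  # 100 * 100 = 10000 elements

try:
    stack_and_reshape(pinned, (NUM_SAMPLES, N_NEW))
except ValueError as err:
    print("reshape failed:", err)

# If 'obs' is actually sampled for the 5 new rows, the reshape succeeds.
fresh = [rng.integers(0, 2, size=N_NEW) for _ in range(NUM_SAMPLES)]
print(stack_and_reshape(fresh, (NUM_SAMPLES, N_NEW)).shape)  # (100, 5)
```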

Any help is greatly appreciated.

[1] Bayesian regression predictions
[2] Fast posterior point estimates without Predictive
[3] Bayesian Hierarchical Linear Regression: how to predict on a new patient having observations?
[4] Out of sample predictions in Pyro
[5] pyro-ppl/pyro on GitHub: pyro/ at commit 0029941c2d35f353506301a996a8e6e637c83e6a

For anyone facing this: after spending some quality time with the Python debugger, I realized that my samples from the posterior (not the posterior predictive) also included samples of the observed site, obs, because Predictive returns every site by default when a guide is given and return_sites is left empty. As far as I can tell, when I then passed these samples as posterior_samples, Predictive conditioned the model on them, so each of the 100 runs reused a stored 100-element obs draw instead of sampling 5 new values. Stacking 100 of those gives 100 × 100 = 10000 elements, which cannot be reshaped to [100, 5].
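An alternative, assuming you want to keep the first Predictive call unchanged, is to drop the observed site from the dictionary before reusing it as posterior_samples. The sketch below uses numpy arrays as stand-ins for the tensors; the g0 and covs shapes are the ones printed above, while the (100, 100) shape for 'obs' is inferred from the 10000 in the error. drop_sites is a hypothetical helper, and the commented-out lines show how the filtered dict would be passed to Predictive.

```python
import numpy as np


def drop_sites(samples, sites_to_drop):
    """Hypothetical helper: return a copy of `samples` without the given sites."""
    return {name: value for name, value in samples.items() if name not in sites_to_drop}


# Stand-ins shaped like the output of the first Predictive call,
# including the 'obs' site it returned by default.
samples = {
    "g0": np.zeros((100, 1)),
    "covs": np.zeros((100, 1, 3)),
    "obs": np.zeros((100, 100)),
}

latent_samples = drop_sites(samples, {"obs"})
print(sorted(latent_samples))  # ['covs', 'g0']

# With Pyro, the filtered dict would then be used as:
# oob_predictive = Predictive(model, posterior_samples=latent_samples, return_sites=['obs'])
# oob_preds = oob_predictive(oob_predictors, None)
```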

This is the correct way to sample predictions on held-out data:

predictive = Predictive(model, guide=guide, num_samples=100, return_sites=['g0', 'covs'])
samples = predictive(predictors, pheno)

oob_predictors = torch.rand([5, 3])
oob_data = torch.bernoulli(torch.sigmoid(0.2 + torch.matmul(oob_predictors, coefs)))

oob_predictive = Predictive(model, posterior_samples=samples, return_sites=['_RETURN'])
oob_preds = oob_predictive(oob_predictors, None)
np.mean(oob_preds['_RETURN'].numpy(), axis=0)
## array([0.15, 0.15, 0.41, 0.28, 0.19], dtype=float32)

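(Note the return_sites specification in both Predictive() objects: the first keeps only the latent sites g0 and covs, and the second asks for _RETURN, the model's return value p. Also note the use of None where I had previously passed pheno, so that obs is sampled rather than observed.)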
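Averaging the 0/1 draws of _RETURN estimates P(obs = 1) for each new row. Because the model is a logistic regression, a lower-variance alternative (a swapped-in technique, not what this thread does) is to push each posterior draw of g0 and covs through the sigmoid directly and average those probabilities, with no Bernoulli sampling. The sketch uses numpy with made-up random stand-ins shaped like samples['g0'] (100, 1) and samples['covs'] (100, 1, 3); with the real tensors you would call .numpy() first. predictive_probabilities is a hypothetical name.

```python
import numpy as np


def predictive_probabilities(g0, covs, x_new):
    """Average sigmoid(g0 + x @ covs) over posterior draws.

    g0:    (S, 1) posterior draws of the intercept
    covs:  (S, 1, 3) posterior draws of the coefficients
    x_new: (N, 3) new predictors
    returns: (N,) posterior mean of P(obs = 1) for each new row
    """
    coefs = covs.reshape(covs.shape[0], -1)  # (S, 3)
    logits = g0 + coefs @ x_new.T            # (S, 1) + (S, N) -> (S, N)
    probs = 1.0 / (1.0 + np.exp(-logits))    # elementwise sigmoid
    return probs.mean(axis=0)                # average over the S draws


rng = np.random.default_rng(0)
g0 = rng.normal(0.2, 0.1, size=(100, 1))                     # stand-in draws
covs = rng.normal([1.0, -5.0, 0.2], 0.1, size=(100, 1, 3))   # stand-in draws
x_new = rng.random((5, 3))                                   # stand-in for oob_predictors

p = predictive_probabilities(g0, covs, x_new)
print(p.shape)  # (5,)
```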
