Inference runs as expected, but Predictive raises an unexpected shape error

I’ve created a model and can run inference on it with SVI, and the results look good. But when I try to use the Predictive class after inference, I get shape errors, apparently because Predictive batches the samples it feeds back into the model. I’ve read and reread the indexing tutorial, but I’m not sure how to handle my case, so I’m looking for help.

I’ve attached a toy model that demonstrates what I’m trying to do:

  • During inference, the samples r, a, and b are 1D and gamma is 2D
  • I then assemble various matrices for a downstream task from r, a, b, and gamma, which requires some indexing
  • But when using Predictive, all of the samples come out 2D; e.g. r picks up a leading singleton dimension, (1, n) instead of (n,), and the indexing breaks (see the shape-check sketch below)
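
For reference, this is roughly how I’m checking the shapes; a minimal sketch run against the toy model attached below, tracing one plain run of the model and printing each sample site’s shape:

import pyro.poutine as poutine

# Trace one plain run of the toy model below and print each sample site's shape.
trace = poutine.trace(model).get_trace()
for name in ["r", "a", "b", "gamma"]:
    print(name, tuple(trace.nodes[name]["value"].shape))
# With n = 3 and p = 2 this prints: r and a as (3,), b as (6,), gamma as (3, 2).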

Questions:

  • Is there an obvious fix to my model to accommodate the batching?
  • Or, is there a way to completely turn off batching in Predictive?
  • Or, is there another way to get posterior samples after inference (e.g. something like the guide-tracing sketch below)?
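
To clarify that last question, the kind of thing I have in mind is the rough sketch below, which traces the trained guide after SVI and reads the latent sites off each trace (assuming guide is the fitted AutoMultivariateNormal from the script further down). I don’t know whether this is a recommended pattern, which is partly why I’m asking:

import torch
import pyro.poutine as poutine

# Draw posterior samples by tracing the trained guide once per draw and
# collecting the latent sites from each trace.
sites = ["r", "a", "b", "gamma"]
draws = {name: [] for name in sites}
for _ in range(500):
    guide_trace = poutine.trace(guide).get_trace()
    for name in sites:
        draws[name].append(guide_trace.nodes[name]["value"].detach())
draws = {name: torch.stack(values) for name, values in draws.items()}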

Runnable code is attached below; thanks in advance.


import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO, Predictive
from pyro.infer.autoguide import AutoMultivariateNormal
from pyro.optim import Adam

def model():
    n = 3
    n_non_diag = n * (n - 1)
    p = 2
    n_ts = 5

    with pyro.plate('n_a', size=n, dim=-1):
        r = pyro.sample("r", dist.Normal(0.0, 1.0))
        a = pyro.sample('a', dist.HalfNormal(1.0))

    with pyro.plate('non_diag', size=n_non_diag, dim=-1):
        b = pyro.sample("b", dist.Normal(0.0, 1.0))

    with pyro.plate('n_g', size=n, dim=-2):
        with pyro.plate("p", p, dim=-1):
            gamma = pyro.sample('gamma', dist.Normal(0.0, 1.0))
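    # Shapes during a plain model run: r and a -> (n,), b -> (n_non_diag,), gamma -> (n, p)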

    # Assemble A matrix
    A = torch.zeros(size=(n, n))
    mask = torch.eye(n, dtype=torch.bool)
    A[mask] = a
    A[~mask] = b
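    # Note: these boolean-mask assignments expect a and b to be 1D, with shapes (n,) and (n_non_diag,)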

    # Prepend a zero column to gamma
    g_0 = torch.zeros(gamma.shape[-2], 1)
    gamma = torch.cat([g_0, gamma], dim=-1)
    
    # Some needed indexing: pick out a gamma column for each of the n_ts time steps,
    # then scale by r (r[:, None] is (n, 1) and broadcasts against the (n, n_ts) gammas)
    gammas = gamma[:, [0, 1, 0, 2, 0]]
    growth = r[:, None] * (1 + gammas)

    for idx in range(n_ts):
        x_next = growth[:, idx] + A @ torch.ones(n)

    # Then the obs statements etc., omitted.


if __name__ == "__main__":
    n_steps = 50
    optimizer = Adam({"lr": 0.001})

    guide = AutoMultivariateNormal(
        model=model,
        init_scale=0.0001,
        )
    
    svi = SVI(model, guide, optimizer, loss=Trace_ELBO()) 

    pyro.clear_param_store()
    for step in range(n_steps):
        loss = svi.step()

    med = guide.median()
    print('ran')

    predictive_svi = Predictive(
        model=model, 
        guide=guide, 
        num_samples=500, parallel=False,
        return_sites=['r', 'a', 'b', 'gamma',]
        )
    
    # This is the call that raises the shape error described above
    pred = predictive_svi()