I’ve created a model, I can run inference with SVI, and the results look good. But when I use the `Predictive` class after inference, I get shape errors, because `Predictive` seems to add a batch dimension. I’ve read and reread the indexing tutorial, but I’m not sure how to handle my case, so I’m looking for help.
I’ve attached a toy model that demonstrates what I’m trying to do:
- During inference, the samples `r`, `a`, `b` are 1D, and `gamma` is 2D.
- I then assemble various matrices for a downstream task using `r`, `a`, `b`, `gamma`, which requires some indexing.
- But when using `Predictive`, all of them are 2D (e.g. `r` is `(1, 5)` instead of just `(5,)`).
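To make the shape difference concrete, here is a minimal torch-only sketch (it does not call Pyro) using the toy model's sizes, `n = 3` and `n_ts = 5`. The tensor `r_predictive` simply mimics the extra leading dimension of size 1 described above; it is an assumption about what the model receives under `Predictive`, not Pyro's actual code path.

```python
import torch

# Toy sizes from the model below: n = 3 states, n_ts = 5 columns after indexing.
n, n_ts = 3, 5
r_inference = torch.randn(n)        # shape (3,): what the model sees during SVI
r_predictive = r_inference[None]    # shape (1, 3): an extra leading dim of size 1
gammas = torch.randn(n, n_ts)       # stand-in for gamma[:, [0, 1, 0, 2, 0]]

# ":" in the first position assumes r has exactly one dim.
ok = r_inference[:, None] * (1 + gammas)        # (3, 1) * (3, 5) -> (3, 5)

# With the extra dim, r[:, None] is (1, 1, 3), which cannot broadcast with (3, 5).
try:
    r_predictive[:, None] * (1 + gammas)
    broadcast_failed = False
except RuntimeError as err:
    broadcast_failed = True
    print("shape error:", err)

# Indexing from the right with "..." puts the new axis after the last dim,
# so the same expression works with or without the leading dim.
fixed_unbatched = r_inference[..., None] * (1 + gammas)   # (3, 5)
fixed_batched = r_predictive[..., None] * (1 + gammas)    # (1, 3, 5)
print(ok.shape, fixed_unbatched.shape, fixed_batched.shape)
```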
Questions:
- Is there an obvious fix to my model to accommodate the batching?
- Or, is there a way to completely turn off batching in Predictive?
- Or, is there another way to get posterior samples after inference?
Runnable code is below. Thanks in advance.
```python
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO, Predictive
from pyro.infer.autoguide import AutoMultivariateNormal
from pyro.optim import Adam


def model():
    n = 3
    n_non_diag = n * (n - 1)
    p = 2
    n_ts = 5

    with pyro.plate('n_a', size=n, dim=-1):
        r = pyro.sample("r", dist.Normal(0.0, 1.0))
        a = pyro.sample('a', dist.HalfNormal(1.0))

    with pyro.plate('non_diag', size=n_non_diag, dim=-1):
        b = pyro.sample("b", dist.Normal(0.0, 1.0))

    with pyro.plate('n_g', size=n, dim=-2):
        with pyro.plate("p", p, dim=-1):
            gamma = pyro.sample('gamma', dist.Normal(0.0, 1.0))

    # Assemble A matrix
    A = torch.zeros(size=(n, n))
    mask = torch.eye(n, dtype=torch.bool)
    A[mask] = a
    A[~mask] = b

    # Append a zero vector to gamma
    g_0 = torch.zeros(gamma.shape[-2], 1)
    gamma = torch.concat([g_0, gamma], axis=-1)

    # Some needed indexing
    gammas = gamma[:, [0, 1, 0, 2, 0]]
    growth = r[:, None] * (1 + gammas)
    for idx in range(n_ts):
        x_next = growth[:, idx] + A @ torch.ones(n)
        # Then obs, etc, omitted.


if __name__ == "__main__":
    n_steps = 50
    optimizer = Adam({"lr" : 0.001})
    guide = AutoMultivariateNormal(
        model=model,
        init_scale=0.0001,
    )
    svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
    pyro.clear_param_store()
    for step in range(n_steps):
        loss = svi.step()
    med = guide.median()
    print('ran')

    predictive_svi = Predictive(
        model=model,
        guide=guide,
        num_samples=500, parallel=False,
        return_sites=['r', 'a', 'b', 'gamma',]
    )
    pred = predictive_svi()
```
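For the first question, one direction (a sketch, not a confirmed fix) is to write the assembly step so it indexes only from the right with `...`, which makes it indifferent to extra leading dimensions. The helper name `assemble`, its return values, and the choice of integer-index assignment instead of the boolean `mask` are my own assumptions; the index pattern `[0, 1, 0, 2, 0]` and the sizes come from the toy model above. It only uses plain torch, so it can be checked outside Pyro.

```python
import torch


def assemble(r, a, b, gamma, gamma_idx=(0, 1, 0, 2, 0)):
    """Hypothetical batch-agnostic version of the toy model's assembly step.

    Shapes (any leading batch dims "..." are allowed):
      r, a: (..., n)   b: (..., n * (n - 1))   gamma: (..., n, p)
    """
    n = a.shape[-1]
    batch_shape = torch.broadcast_shapes(a.shape[:-1], b.shape[:-1])

    # A: diagonal from `a`; off-diagonal from `b` in row-major order,
    # i.e. the same order that A[~mask] = b uses in the original model.
    A = a.new_zeros(batch_shape + (n, n))
    diag = torch.arange(n)
    rows, cols = (~torch.eye(n, dtype=torch.bool)).nonzero(as_tuple=True)
    A[..., diag, diag] = a
    A[..., rows, cols] = b

    # Prepend a zero column to gamma, keeping all leading dims.
    g_0 = gamma.new_zeros(gamma.shape[:-1] + (1,))
    gamma = torch.cat([g_0, gamma], dim=-1)

    # Index the last axis only, so leading dims pass through untouched.
    gammas = gamma[..., list(gamma_idx)]
    growth = r[..., None] * (1 + gammas)

    # One x per time step, stacked along a new last axis: (..., n, n_ts).
    ones = a.new_ones(n)
    xs = torch.stack(
        [growth[..., t] + A @ ones for t in range(growth.shape[-1])], dim=-1
    )
    return A, growth, xs


# Toy sizes from the model above: n = 3, p = 2.
torch.manual_seed(0)
n, p = 3, 2
r, a = torch.randn(n), torch.rand(n)
b, gamma = torch.randn(n * (n - 1)), torch.randn(n, p)

A, growth, xs = assemble(r, a, b, gamma)                       # unbatched
A1, growth1, xs1 = assemble(r[None], a[None], b[None], gamma)  # leading dim of 1
print(A.shape, growth.shape, xs.shape)     # shapes (3, 3), (3, 5), (3, 5)
print(A1.shape, growth1.shape, xs1.shape)  # shapes (1, 3, 3), (1, 3, 5), (1, 3, 5)
```

Calling it on the unbatched tensors and on copies with an added leading dimension of size 1 gives the same values up to that extra dimension. Whether this is enough for the full model (including the omitted `obs` sites) is untested.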