Hello,
I would like to write a mixture model in Pyro using parallel enumeration, in which one of my plates is not directly represented in the observed data. However, when computing svi.loss, I get an error that I cannot explain: “ValueError: Error while packing tensors”. I think it is related either to the plate that is not represented in the data or to the dimension Pyro wants to use for parallel enumeration.
Below is the code for a simplified version of the model, followed by the lines that create the model/guide and reproduce the error:
import pyro
import torch
from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate
from pyro.infer.autoguide import AutoDelta
from pyro import poutine
import pyro.distributions as dist
def check_model(model, *args):
    pyro.clear_param_store()
    trace = poutine.trace(model).get_trace(*args)
    print(trace.format_shapes())
    return pyro.render_model(model, model_args=args)
@config_enumerate
def full_model(data):
    plate1 = pyro.plate("plate1", 5, dim=-1)
    plate2 = pyro.plate("plate2", data.shape[-2], dim=-2)
    plate3 = pyro.plate("plate3", data.shape[-3], dim=-3)
    with plate3:
        A = pyro.sample("A", dist.Normal(torch.tensor([0., 0.2, 0.4, 0.6, 0.8]),
                                         torch.tensor([1., 0.8, 0.6, 0.4, 0.2])))
        probs = pyro.sample('probs', dist.Dirichlet(torch.tensor([30., 70.])))
        assignment = pyro.sample('assignment', dist.Categorical(probs), infer={"enumerate": "parallel"})
    with plate2:
        B = pyro.sample("B", dist.Uniform(0., 1.))
    prediction = torch.where(assignment==1,
                             (A*B).sum(-1).unsqueeze(-1),
                             (A*-B).sum(-1).unsqueeze(-1))
    with plate3, plate2:
        X = pyro.sample("X", dist.Gamma(torch.exp(prediction),
                                        torch.tensor(1.)), obs=data)
my_data = torch.ones((100, 10000, 1))
check_model(full_model, my_data)
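# 'lrd' (learning-rate decay) is a ClippedAdam option; torch's Adam rejects it once svi.step builds the optimizer
optim = pyro.optim.ClippedAdam({'lr': 0.5, 'lrd': 0.01, 'betas': (0.80, 0.99)})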
pyro.clear_param_store()
full_guide = AutoDelta(poutine.block(full_model, hide=["assignment"]))
loss_func = pyro.infer.TraceEnum_ELBO(max_plate_nesting=3)
svi = SVI(full_model, full_guide, optim, loss=loss_func)
svi.loss(full_model, full_guide, my_data)
The error I receive is:
ValueError: Error while packing tensors at site 'A':
Invalid tensor shape.
Allowed dims: -3
Actual shape: (100, 1, 5)
Try adding shape assertions for your model's sample values and distribution parameters.
Trace Shapes:
Param Sites:
AutoDelta.A 100 1 5
AutoDelta.probs 100 1 1 2
AutoDelta.B 10000 1
Sample Sites:
A dist 100 1 5 |
value 100 1 5 |
log_prob 100 1 5 |
There seems to be a problem with the shape of A; I am not even sure whether parallel enumeration is possible in this case.
It might be distantly related to the following forum question, although their solution has not worked for me:
I would greatly appreciate any suggestions you can provide. Thank you in advance!
Alex
P.S. Very impressed with pyro, great work!