Thank you for the quick reply! I realize that in creating my toy model above, I made a mistake in communicating my problem. The dimensions of plate1 should actually match that of variable A, which is not an event dimension. Both should be of dimension size 5. I’ve corrected this in the original post.
With that correction, my original problem is still unsolved, since I don’t think I want to send A through to_event. I can get the toy code running by using to_event(1) and then indexing later on, but I’m not sure whether this is good practice, or whether it will translate to more complex models.
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import config_enumerate


@config_enumerate
def full_model(data):
    # plate1 is declared but never entered below; the 5 elements of A are
    # currently handled by to_event(1) instead.
    plate1 = pyro.plate("plate1", 5, dim=-1)
    plate2 = pyro.plate("plate2", data.shape[-2], dim=-2)
    plate3 = pyro.plate("plate3", data.shape[-3], dim=-3)
    with plate3:
        # Workaround: the last dimension of A is declared as an event dimension.
        A = pyro.sample("A", dist.Normal(torch.tensor([0., 0.2, 0.4, 0.6, 0.8]),
                                         torch.tensor([1., 0.8, 0.6, 0.4, 0.2])).to_event(1))
    probs = pyro.sample('probs', dist.Dirichlet(torch.tensor([30., 70.])))
    assignment = pyro.sample('assignment', dist.Categorical(probs), infer={"enumerate": "parallel"})
    with plate2:
        B = pyro.sample("B", dist.Uniform(0., 1.))
    # The extra singleton dimension left over from to_event(1) is indexed away here.
    prediction = torch.where(assignment==1,
                             (A[:, :, 0, :]*B).sum(-1).unsqueeze(-1),
                             (A[:, :, 0, :]*-B).sum(-1).unsqueeze(-1))
    with plate3, plate2:
        X = pyro.sample("X", dist.Gamma(torch.exp(prediction),
                                        torch.tensor(1.)), obs=data)
Thank you for your help and patience!