I’m trying to train a model. When I did the training with an AutoNormal guide, everything went as expected. Since I will have to scale this model to larger data sizes, I need to use minibatches, and subsampling is apparently easier with a custom guide, so I am trying to reproduce the behaviour I had with AutoNormal, but it is not working. I have already tried playing around with the parameter initialisation, but that didn’t change anything.
import torch
import pyro
import pyro.distributions as dist
from pyro.distributions import constraints

# device is defined as usual elsewhere in my script; shown here for completeness
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def model(data):
    D = data.shape[0]  # number of data points
    d = data.shape[1]  # number of categorical features per data point
    W_G_loc = torch.zeros((3, 2))
    W_G_scale = torch.full((3, 2), 1.)
    # one (3, 2) weight matrix per feature: W_G has shape (d, 3, 2)
    with pyro.plate("d", size=d):
        W_G = pyro.sample("W_G", dist.Normal(W_G_loc, W_G_scale).to_event(2))
    with pyro.plate("D", size=D):
        z = pyro.sample("z", dist.Normal(torch.tensor([0., 0.]), torch.tensor([1., 1.])).to_event(1))
        # (d, 3, 2) @ (2, D) -> (d, 3, D), permuted to (D, d, 3) logits
        W_Gxz = torch.permute(torch.matmul(W_G, z.T), (2, 0, 1))
        pyro.sample("obs", dist.Categorical(logits=W_Gxz.to(device)).to_event(1), obs=data)
def guide(data):
    D = data.shape[0]
    d = data.shape[1]
    # per-feature variational parameters for W_G, shape (d, 3, 2)
    W_G_loc = pyro.param("W_G_loc", torch.zeros(d, 3, 2))
    W_G_scale = pyro.param("W_G_scale", torch.full((d, 3, 2), 0.1),
                           constraint=constraints.softplus_positive)
    with pyro.plate("d", size=d):
        pyro.sample("W_G", dist.Normal(W_G_loc, W_G_scale).to_event(2))
    # a single (loc, scale) pair for z, shared across all D data points
    z_loc = pyro.param("z_loc", torch.tensor([0., 0.]))
    z_scale = pyro.param("z_scale", torch.tensor([0.1, 0.1]),
                         constraint=constraints.softplus_positive)
    with pyro.plate("D", size=D):
        pyro.sample("z", dist.Normal(z_loc, z_scale).to_event(1))