How to write a guide analogous to AutoNormal

I’m trying to train a model. When I trained it with AutoNormal, everything went as expected. Since I will have to scale this model to larger datasets, I need to use minibatches, and this is apparently easier with a custom guide. So I am trying to reproduce the behaviour I had with AutoNormal, but this is not working. I already tried playing around with the parameter initialisation, but it didn’t change anything.

def model(data):
    """Generative model for categorical data driven by per-row latent vectors.

    Each column ``j`` of ``data`` owns a (3, 2) weight matrix ``W_G[j]``, and each
    row ``i`` owns a 2-d latent ``z[i]``. The logits for entry ``data[i, j]`` are
    ``W_G[j] @ z[i]`` (3 logits, i.e. 3 categories).

    Args:
        data: tensor of shape (D, d) used as the ``obs`` of a Categorical with 3
            logits, so entries are presumably integer labels in {0, 1, 2} —
            TODO confirm against the caller.
    """
    D = data.shape[0]
    d = data.shape[1]

    # Prior over a single column's weight matrix.
    W_G_loc = torch.zeros((3, 2))
    W_G_scale = torch.full((3, 2), 1.)

    # to_event(2) treats the whole (3, 2) matrix as one event, so the "d" plate
    # supplies the only batch dimension: W_G has shape (d, 3, 2).
    with pyro.plate("d", size=d):
        W_G = pyro.sample("W_G", dist.Normal(W_G_loc, W_G_scale).to_event(2))

    with pyro.plate("D", size=D):
        # One standard-normal 2-d latent per row: z has shape (D, 2).
        z = pyro.sample("z", dist.Normal(torch.zeros(2), torch.ones(2)).to_event(1))
        # (d, 3, 2) @ (2, D) -> (d, 3, D); permute to (D, d, 3) so that the last
        # dim holds the 3 category logits for entry (row, column).
        W_Gxz = torch.permute(torch.matmul(W_G, z.T), (2, 0, 1))
        # to_event(1) folds the d columns into one event; the "D" plate covers rows.
        # NOTE(review): `device` is a module-level global not visible here.
        pyro.sample("obs", dist.Categorical(logits=W_Gxz.to(device)).to_event(1), obs=data)

def guide(data):
    """Mean-field Normal guide mirroring what ``AutoNormal`` builds for ``model``.

    Every latent site gets an independent Normal whose loc/scale have the same
    shape as the site's value:

    * ``W_G`` (global, one (3, 2) matrix per column): params of shape (d, 3, 2).
    * ``z``   (local, one 2-d vector per row):          params of shape (D, 2).

    ``z`` is local, so each data row needs its *own* variational parameters.
    A single shared (2,) loc/scale would force every row onto the same
    posterior.

    Args:
        data: the same (D, d) tensor passed to ``model``.
    """
    D = data.shape[0]
    d = data.shape[1]

    # Lazy (callable) initial values are evaluated only when the param is first
    # created. Small random locs break the W_G <-> z sign symmetry at 0, similar
    # to AutoNormal's default init_to_median, which draws from the prior.
    W_G_loc = pyro.param("W_G_loc", lambda: 0.1 * torch.randn(d, 3, 2))
    W_G_scale = pyro.param("W_G_scale", torch.full((d, 3, 2), 0.1),
                           constraint=constraints.softplus_positive)

    with pyro.plate("d", size=d):
        pyro.sample("W_G", dist.Normal(W_G_loc, W_G_scale).to_event(2))

    # Per-row local parameters, shape (D, 2).
    z_loc = pyro.param("z_loc", lambda: 0.1 * torch.randn(D, 2))
    z_scale = pyro.param("z_scale", torch.full((D, 2), 0.1),
                         constraint=constraints.softplus_positive)

    with pyro.plate("D", size=D) as idx:
        # Without subsampling, idx is arange(D) and selects every row.
        # With pyro.plate("D", size=N, subsample_size=B) (in model and guide
        # alike), idx picks only the minibatch rows' parameters.
        pyro.sample("z", dist.Normal(z_loc[idx], z_scale[idx]).to_event(1))

1 Like

I don’t know what “not working” means here, but in general, scale parameters of Normal distributions in guides should be initialized to small values, such as 0.01.