Simple model parameters not converging

Hi, I’m new to probabilistic programming and am trying to get a toy Dirichlet-Multinomial model working. I’m not sure why, but when I run inference the parameters either barely update or move in entirely wrong directions. Here’s my code:

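import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam
from torch.distributions import constraints

data = [10.0, 12.0, 2.0]

def model(data=data):
    d = torch.tensor(data)
    return pyro.sample("obs", dist.DirichletMultinomial(torch.ones(len(data)), sum(data)), obs=d)

def guide(data=data):
    params = pyro.param("alphas", torch.ones(len(data)), constraint=constraints.positive)
    return pyro.sample("obs", dist.DirichletMultinomial(params, sum(data)))

adam_params = {"lr": 0.01}
optimizer = Adam(adam_params)
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
alphas, losses = [], []
n_steps = 5000
for step in range(n_steps):
    # (the loop body was cut off in the original post; presumably something like this)
    losses.append(svi.step())
    alphas.append(pyro.param("alphas").detach().clone())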

Any advice would be hugely appreciated, thanks!!

notice that your model has no latent variables, only observed data. consequently the only “inference” you could possibly do is maximum likelihood estimation of the parameters in your model. however, the parameters in your model are fixed (torch.ones(len(data))), so there is nothing for you to learn. in this context your guide should be empty, since, again, there are no latent variables for you to infer. in particular, the guide should never sample the observed site "obs".


Thanks so much for the response; that makes sense and it seems to be working now. Just to clear up my own ignorance, though: is the reason only maximum likelihood is possible here, rather than MAP, just that I haven’t specified a higher-level prior over the parameters of the Dirichlet?

Here’s the code that works now, in case anyone is stuck on something similar:

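import torch
import pyro
import pyro.distributions as dist
from torch.distributions import constraints

def model(data=data):
    d = torch.tensor(data)
    # alphas is now a learnable parameter; the positive constraint keeps it a valid concentration
    alphas = pyro.param("alphas", torch.ones(len(data)), constraint=constraints.positive)
    return pyro.sample("obs", dist.DirichletMultinomial(alphas, sum(data)), obs=d)

def guide(data=data):
    # no latent variables, so the guide is empty
    return None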

yes, that’s right. if you added additional latent variables then you could do MAP, variational inference, etc.

note that the DirichletMultinomial distribution already integrates out the Dirichlet latent variable exactly; in that sense the latent variable disappears from the model.
