I'm slightly confused about how to use TraceEnum_ELBO() as the loss criterion.

To enumerate all sample sites at once, the docs suggest using @config_enumerate (and is it @config_enumerate or @config_enumerate’?).

Also, if I place @config_enumerate above both the model and the guide functions, is that sufficient?

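For example:

```python
from pyro.infer import config_enumerate

@config_enumerate
def model(images, labels): ...  # body omitted in the post

@config_enumerate
def guide(images, labels): ...  # body omitted in the post

def do_inference(): ...  # inference loop omitted in the post
```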

My ELBO fluctuates wildly, and I'm not sure whether I'm using the loss criterion correctly.


To be more specific, here are my model and guide functions; I have included the inference part as well. I strongly suspect I have messed up somewhere.

prior_distribution() and categorical_variational_parameters() just return probability matrices used for initialization.


Hi @srikanthram, I don’t think you can use enumeration in your model, because your observation depends on a joint function of the categorical variables. This isn’t so much a Pyro limitation as a mathematical one: to enumerate over the categorical variables, you’d need to enumerate over their Cartesian product, of size num_categories ** 800 etc., and that space is so large that the computation is infeasible. Instead of enumeration, I think you’ll want to use variational inference and learn the categorical distributions in your guide.
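To make the size argument concrete, here is a small pure-Python sketch. It assumes a hypothetical `num_categories = 10` (the real value isn't given in the thread) and 800 categorical sites, and compares enumerating the joint Cartesian product with enumerating sites that each feed into their own observation independently.

```python
def joint_enumeration_size(num_categories, num_sites):
    # Every joint assignment of all sites must be visited: K ** N terms.
    return num_categories ** num_sites

def factorized_enumeration_size(num_categories, num_sites):
    # If each observation depends on a single site, the sum factorizes
    # and each site is enumerated on its own: K * N terms.
    return num_categories * num_sites

K, N = 10, 800  # K is a hypothetical choice; N comes from the reply above
print(len(str(joint_enumeration_size(K, N))))  # number of decimal digits
print(factorized_enumeration_size(K, N))
```

With these values the joint space has 10 ** 800 configurations (an 801-digit number), while the factorized case needs only 8000 terms.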


@fritzo So this means using Trace_ELBO() rather than TraceEnum_ELBO()? I did try that, but surprisingly my loss is increasing instead of decreasing. I strongly suspect my pyro.param() values are not being updated at each iteration.


Yes, you would use Trace_ELBO or, better, TraceGraph_ELBO. Training discrete models is hard, and you may need a low learning rate or baselines. You could also try a relaxed categorical distribution, but I don’t have much experience with relaxed distributions.