Hi @srikanthram, I don’t think you can use enumeration in your model, because your observation depends on a joint function of the categorical variables. This isn’t so much a Pyro limitation as a limitation of mathematics: to enumerate over the categorical variables, you’d really need to enumerate over their Cartesian product, which has size num_categories ** 800 etc., and that space is so large that computation is infeasible. Instead of enumeration, I think you’ll want to use variational inference and learn the categorical distributions in your guide.
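To put a rough number on the size of that space, here is a small sketch in plain Python. The helper name `joint_space_digits` and the category counts tried are only illustrative assumptions; the 800 comes from the figure above.

```python
def joint_space_digits(num_categories: int, num_variables: int) -> int:
    """Return how many decimal digits num_categories ** num_variables has."""
    # Python integers are arbitrary precision, so the power is computed exactly.
    return len(str(num_categories ** num_variables))


# Hypothetical category counts; 800 variables as in the post above.
for k in (2, 3, 10):
    digits = joint_space_digits(k, 800)
    print(f"{k} categories -> joint space size has {digits} decimal digits")
```

Even with only two categories per variable, the number of joint assignments has hundreds of digits, so a sum over all of them cannot be carried out term by term.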
@fritzo Does this mean I should use Trace_ELBO() rather than TraceEnum_ELBO()? I did try that out, but surprisingly my loss is increasing instead of decreasing. I strongly suspect my pyro.param values are not getting updated during each iteration.
Yes, you would use Trace_ELBO or, better, TraceGraph_ELBO, which uses the dependency structure of the model to reduce the variance of the gradient estimate. Training discrete models is hard, and you may need to use a low learning rate or provide baselines. You could also try using a relaxed categorical distribution, but I don’t have much experience with relaxed distributions.