ELBO Loss Computation for Different Event/Batch Shapes of the Same Sample Site

Dear Pyro Community,

I have a DMM architecture similar to the one in the tutorial (https://pyro.ai/examples/dmm.html). The main difference is that my observations are one-hot encoded at each time step. Therefore, inside the model function of the DMM, I changed
pyro.sample("obs_x_%d" % t, dist.Bernoulli(emission_probs_t).mask(mini_batch_mask[:, t - 1:t]).to_event(1), obs=mini_batch[:, t - 1, :])
to
pyro.sample("obs_x_%d" % t, dist.OneHotCategorical(emission_probs_t).mask(mini_batch_mask[:, t - 1]).to_event(1), obs=mini_batch[:, t - 1, :])
and it seemed to work great.

However, I then realized that .to_event(1) is not needed for OneHotCategorical, because the last dimension of emission_probs_t is already treated as the event_shape without .to_event(1) (unlike Bernoulli). When I fixed this mistake, my model performed worse, so I am trying to understand how the ELBO is computed and what changed.
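To make the shape claim concrete, here is a minimal sketch using plain torch.distributions rather than Pyro. It assumes that Independent(d, 1) plays the role of Pyro's d.to_event(1); the sizes and the random probs are hypothetical stand-ins for emission_probs_t, and the .mask(...) call is left out.

```python
import torch
from torch.distributions import Bernoulli, Independent, OneHotCategorical

# Hypothetical sizes; probs stands in for emission_probs_t at one time step.
batch_size, num_categories = 4, 3
probs = torch.softmax(torch.randn(batch_size, num_categories), dim=-1)

candidates = {
    # Bernoulli treats every entry as its own batch element.
    "bernoulli": Bernoulli(probs=probs),
    # Reinterpreting the last batch dim as an event dim (like .to_event(1)).
    "bernoulli_to_event": Independent(Bernoulli(probs=probs), 1),
    # OneHotCategorical already consumes the last dim as its event dim.
    "onehot": OneHotCategorical(probs=probs),
    # A further .to_event(1) absorbs the mini-batch dim into the event.
    "onehot_to_event": Independent(OneHotCategorical(probs=probs), 1),
}

shapes = {
    name: (tuple(d.batch_shape), tuple(d.event_shape))
    for name, d in candidates.items()
}
for name, (batch_shape, event_shape) in shapes.items():
    print(f"{name}: batch_shape={batch_shape}, event_shape={event_shape}")
```

The last entry is the case where the mini-batch dimension ends up inside the event_shape.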

To be clear about what I mean by the model performing worse, here are some statistics for the two cases (all loss values are normalized by N_train_time_slices, as in the original code):
With .to_event(1): The training loss started at 250 and decreased to 2.5 (and similarly for the validation set). I also checked how well I can reconstruct the observations in the test set (I used pyro.infer.Predictive to get the hidden states, then used the emitter to get the emission probs), and I could reconstruct them with 100% accuracy!

  • For dist.OneHotCategorical: batch_shape=[], event_shape=[len(batch), num_categories]

Without .to_event(1): The training loss started at 2.2 (the loss seems to be proportional to my batch_size, which is 105; why?) and decreased to 0.9 (and similarly for the validation set). This time I could reconstruct the observations in the test set with only 85% accuracy.

  • For dist.OneHotCategorical: batch_shape=[len(batch)], event_shape=[num_categories]

My intuition says that I shouldn’t use .to_event(1) in this case, and I have no idea why the model seems to work better with .to_event(1) (even though it has higher loss).
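The two shape configurations listed above can also be compared in terms of what log_prob returns. The sketch below uses plain torch.distributions with hypothetical sizes and random data, and omits the mask. The final step rests on an assumption that would need checking against the real model: that the sample site sits inside a pyro.plate over the mini-batch (as in the tutorial) and that the plate expands an empty batch_shape to the plate size. If that holds, the summed log-likelihood would be counted once per mini-batch element, which might account for a loss gap of roughly batch_size.

```python
import torch
from torch.distributions import Independent, OneHotCategorical

torch.manual_seed(0)

# Hypothetical sizes and data standing in for one time step of the DMM.
batch_size, num_categories = 5, 3
probs = torch.softmax(torch.randn(batch_size, num_categories), dim=-1)
labels = torch.randint(num_categories, (batch_size,))
obs = torch.nn.functional.one_hot(labels, num_categories).float()

# Without .to_event(1): one log-probability per mini-batch element.
per_item = OneHotCategorical(probs=probs).log_prob(obs)

# With .to_event(1): the mini-batch dim is summed into a single scalar.
joint = Independent(OneHotCategorical(probs=probs), 1)
joint_lp = joint.log_prob(obs)

# Assumption: a plate of size batch_size expands the empty batch_shape,
# so the same scalar appears batch_size times before the final sum.
expanded_lp = joint.expand(torch.Size([batch_size])).log_prob(obs)

print("per-item shape:", tuple(per_item.shape))
print("joint shape:", tuple(joint_lp.shape))
print("expanded shape:", tuple(expanded_lp.shape))
print("ratio of totals:", (expanded_lp.sum() / per_item.sum()).item())
```

Under that assumption the observation term is weighted batch_size times more heavily relative to the other ELBO terms, so the two runs would not be optimizing the same objective.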