SS-VAE unsupervised loss problems


Hi, I’m trying to adapt the SS-VAE example for my own use, and in particular I’m replacing the one-hot categorical latent “y” with an array of Bernoulli latent variables.

The problem is that the unsupervised loss (when y is not observed) is always 0, so I suspect the model is not learning from the unsupervised examples.

After some debugging, I think the issue could be related to my Bernoulli latents and the “enum_discrete” mode. If I step into the “step” function and follow the debugger down to Trace_ELBO.loss_and_grads(), it computes “weight” values of 0 for every example in the batch. Further down, the “scale” returned by iter_discrete_traces() is all zeros for some reason…

Any clues or suggestions would be much appreciated.

I don’t know if it’s related, but I’ve noticed that every time this latent is sampled in the model or guide function (using pyro.sample), every element of the returned array has the same value. For example, I can call y = pyro.sample("y", dist.bernoulli, alpha_y), where alpha_y is all 0.5s, and the returned y is either all 0s or all 1s. I’d expect a mixture of 0s and 1s.



Hi Paul, I’m not sure why the unsupervised loss is always zero, but I can explain some of the behavior you’re seeing in the debugger. When you use enum_discrete=True, Pyro does not sample your discrete latent variables (here, your Bernoullis) at random; it enumerates their values deterministically and in lock-step. It first generates a program trace with all zeros and computes the ELBO for that trace, then generates a program trace with all ones and computes the ELBO for that trace. The ELBOs from the two traces are then combined with respect to the weight of each trace (which may be element-wise).
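A minimal pure-Python sketch of the lock-step enumeration described above, not Pyro’s actual implementation. It assumes each trace’s weight is the product of the guide’s per-element Bernoulli probabilities for the enumerated value; `lockstep_enumerate` and the `elbo_fn` callback are hypothetical names used only for illustration.

```python
import math

def lockstep_enumerate(probs, elbo_fn):
    """Weighted ELBO over the two lock-step passes (hypothetical helper).

    probs: per-element Bernoulli probabilities, strictly between 0 and 1.
    elbo_fn: hypothetical callback returning the ELBO for a given y.
    """
    total = 0.0
    traces = []
    for value in (0, 1):
        # Lock-step: every element of y takes the same value in this pass.
        y = [value] * len(probs)
        # Assumed weight: product of the per-element probabilities of `value`.
        log_w = sum(math.log(p if value == 1 else 1.0 - p) for p in probs)
        weight = math.exp(log_w)
        traces.append((y, weight))
        total += weight * elbo_fn(y)
    return total, traces

total, traces = lockstep_enumerate([0.5, 0.5], elbo_fn=lambda y: 1.0)
# traces holds ([0, 0], ~0.25) and ([1, 1], ~0.25)
```

Within each pass all elements of y are identical, which is consistent with the all-0s or all-1s samples described in the question.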

My guess is that in the debugger you’re only seeing the first pass; if you step or continue to the second pass (via a breakpoint or similar), you’ll see an all-1s sample. I suspect the weight for this second pass will be 1 and the scale will be all 1s. I’m not sure why you’re seeing such extreme weights, but they likely indicate that your Bernoulli probs or logits are being pushed to extreme values.
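To see how extreme logits translate into extreme trace weights, here is a small sketch under the same assumption (each pass is weighted by the product of per-element probabilities). `sigmoid` and `trace_weights` are hypothetical helpers, not part of Pyro.

```python
import math

def sigmoid(x):
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)

def trace_weights(logits):
    """Weights of the all-0s and all-1s passes, assuming each is the
    product of per-element Bernoulli probabilities (hypothetical helper)."""
    w_zeros = math.prod(sigmoid(-l) for l in logits)  # P(y_i = 0) = sigmoid(-logit)
    w_ones = math.prod(sigmoid(l) for l in logits)    # P(y_i = 1) = sigmoid(logit)
    return w_zeros, w_ones

print(trace_weights([0.0, 0.0]))    # (0.25, 0.25)
print(trace_weights([20.0, 20.0]))  # all-0s weight is tiny, all-1s weight is ~1
```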

HTH, -Fritz


Hi Fritz,

Thanks for the detailed and informative response! Do you have any particular references, so I can learn how this works in more detail?

Debugging at line 51 (scale = torch.exp(log_pdf.detach())), log_pdf appeared to contain a very large value in the first element (something like 1e5), followed by sensible-looking values like 0.4. After this line executed, scale was all zeros. Could the first element have caused torch.exp to produce all zeros?
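A quick way to check how exp behaves at the extremes, sketched with the Python standard library rather than torch. `to_float32` is a hypothetical helper that rounds to float32, PyTorch’s default dtype; note that torch.exp returns inf on overflow, whereas math.exp raises OverflowError.

```python
import math
import struct

def to_float32(x):
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]

# A very negative log-density underflows to exactly 0.0 after exp().
underflow64 = math.exp(-1e5)

# float32 underflows much sooner: exp(-120) is about 7.7e-53, below the
# smallest positive float32 (about 1.4e-45), so it rounds to 0.0.
underflow32 = to_float32(math.exp(-120.0))

# A very positive log-density overflows instead of producing 0.
try:
    math.exp(1e5)
    overflowed = False
except OverflowError:
    overflowed = True
```

In float32, exp(x) is exactly 0 only when x is below roughly -104, so an all-zeros scale points to very negative log values; a value like 1e5 on its own would give inf.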



Hi Paul,

The best references I know of are:

If the first element of scale is much larger than the others, then the weight of that element should be 1 and the weights of all other elements should be 0. You could try watching the weight variable in Trace_ELBO._get_traces(); it should be normalized, and I suspect it will be [1, 0, 0, ..., 0].
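A sketch of why one dominant log-scale produces a one-hot weight vector, assuming the weights are normalized with the usual subtract-the-max trick. This is a generic illustration; `normalize_log_weights` is a hypothetical name, not the code inside Trace_ELBO._get_traces().

```python
import math

def normalize_log_weights(log_w):
    """Normalize unnormalized log-weights so they sum to 1.

    Subtracting the max before exp() keeps the largest term at exp(0) = 1,
    which avoids overflow; much smaller terms may still underflow to 0.
    """
    m = max(log_w)
    w = [math.exp(x - m) for x in log_w]
    total = sum(w)
    return [x / total for x in w]

weights = normalize_log_weights([1e5, 0.4, 0.4, 0.4])
# exp(0.4 - 1e5) underflows, so weights is [1.0, 0.0, 0.0, 0.0]
```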


Thanks Fritz,

Another thing I’m observing: if I increase the number of Bernoulli latents in the model, the overall unsupervised loss for each batch becomes smaller, and eventually becomes zero.

So, if I drastically reduce the number of latents, the loss is larger, and if I use my original (fairly large) number of Bernoulli latents, the loss becomes zero.

Also, the unsupervised loss is always much smaller than the supervised ELBO loss. I’d expect them to be of a similar magnitude, at least after a number of epochs…

Is this behavior something you’d expect?
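One possibility worth checking, sketched under loudly stated assumptions: if each example’s weight for a lock-step pass is exp() of the summed log-probabilities of its n Bernoulli latents, and that exp() is evaluated in float32 before any normalization, the weight shrinks geometrically with n (0.5^n for probs of 0.5) and rounds to exactly 0 somewhere around n = 150. Because only two of the 2^n configurations are visited, under these assumptions the unsupervised term would also be much smaller than the supervised one. `lockstep_weight` and `to_float32` are hypothetical helpers, not Pyro code.

```python
import math
import struct

def to_float32(x):
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]

def lockstep_weight(n, p=0.5):
    """float32 weight of one lock-step pass over n Bernoulli(p) latents,
    assuming it is exp() of the sum of n equal per-element log-probabilities."""
    log_w = n * math.log(p)  # grows linearly more negative with n
    return to_float32(math.exp(log_w))

for n in (1, 10, 100, 200):
    print(n, lockstep_weight(n))  # the n = 200 weight is 0.0
```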