Unexpected Weight Update in SSVAE Framework

Hi, dear Pyro community!

I have built a semi-supervised VAE network similar to https://pyro.ai/examples/ss-vae.html, but with a different hierarchy of variables. Unfortunately, I cannot share the full code, since it is more than 1000 lines of mess right now…

First, a quick description of the network: x is the observation vector, y is the binary label of x, and z is the latent variable. In the semi-supervised setting, not every x has a corresponding y label.
Model: I model p(z), p(x|z) and p(y|z). Sampling from p(y|z) is only done when y is observed.
Guide: I model q(z|x,y) and q(y). q(y) is only used when y is not observed, and TraceEnum_ELBO is used to enumerate over y.
The parameters of all these distributions are controlled by separate neural network modules.
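
Written out, and assuming the sites are wired exactly as described above, the two per-datapoint objectives would look roughly like this. The names L_labelled and L_unlabelled are only my labels for the two ELBO terms, and the sum over y reflects the enumeration of the binary label:

```latex
\mathcal{L}_{\text{labelled}}(x, y)
  = \mathbb{E}_{q(z \mid x, y)}\big[\log p(z) + \log p(x \mid z) + \log p(y \mid z) - \log q(z \mid x, y)\big]

\mathcal{L}_{\text{unlabelled}}(x)
  = \sum_{y \in \{0, 1\}} q(y)\, \mathbb{E}_{q(z \mid x, y)}\big[\log p(z) + \log p(x \mid z) - \log q(z \mid x, y) - \log q(y)\big]
```

Under this reading, p(y|z) does not appear in the unlabelled objective at all, so its network should receive no gradient from unlabelled batches.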

In every training epoch, I first iterate over the batches of unlabelled data and then over the batches of labelled data.

The problem: In the 1st epoch, the network weights of p(y|z) are not updated during the unlabelled batches and only start being updated during the labelled batches (as expected). However, from the 2nd epoch onwards, the weights of p(y|z) are also updated during the unlabelled batches, even though I am 100% sure that no sampling occurs for p(y|z) at that time. Can someone please explain how this is possible?
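
To make the symptom concrete, here is a minimal plain-PyTorch sketch (not my actual code) of one way this pattern can arise. The two `Linear` layers and the `step` helper are hypothetical stand-ins, and the sketch assumes an Adam-style optimizer whose gradients are zeroed in place rather than reset to `None`. I have not confirmed that this is what happens inside my Pyro training loop.

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-ins for two sub-networks that share one optimizer.
decoder_x = torch.nn.Linear(4, 4)     # plays the role of the p(x|z) network
classifier_y = torch.nn.Linear(4, 1)  # plays the role of the p(y|z) network
params = list(decoder_x.parameters()) + list(classifier_y.parameters())
optimizer = torch.optim.Adam(params, lr=1e-2)


def step(use_classifier):
    """Run one optimizer step; return True if classifier_y's weights changed."""
    before = classifier_y.weight.detach().clone()
    z = torch.randn(8, 4)
    loss = decoder_x(z).pow(2).mean()
    if use_classifier:
        # Only "labelled" steps put classifier_y into the loss.
        loss = loss + classifier_y(z).pow(2).mean()
    loss.backward()
    optimizer.step()
    # Assumption: gradients are zeroed in place, not reset to None.
    optimizer.zero_grad(set_to_none=False)
    return not torch.equal(before, classifier_y.weight.detach())


# "Unlabelled" step before any labelled one: grad is None, so Adam skips it.
moved_first_unlabelled = step(use_classifier=False)
# "Labelled" step: a real gradient reaches classifier_y.
moved_labelled = step(use_classifier=True)
# "Unlabelled" step afterwards: grad is a zero tensor, but Adam's running
# averages from the labelled step are still non-zero.
moved_second_unlabelled = step(use_classifier=False)

print(moved_first_unlabelled, moved_labelled, moved_second_unlabelled)
# → False True True
```

If something like this is going on, the p(y|z) weights would move on unlabelled batches purely because of the optimizer's momentum state, even though that term contributes no gradient.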

I would also appreciate any suggestions on how to avoid this problem.