If I want the model to exactly reproduce a new image with the same order of entries (logits) as the original input image, rather than just minimizing the KL-divergence under the "bag-of-words" assumption, do I need to manually construct an auxiliary MSE reconstruction loss to check which entries are reproduced correctly and which are wrong?
Take the Semi-Supervised Variational Autoencoder as an example. From the example ss_vae_M2.py I see that it minimizes the ELBO loss between the model and guide, plus an auxiliary "reconstruction" (classification) loss on the label y. Do I need to construct a reconstruction loss for xs as well?
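For reference, the auxiliary loss on y that I mean is, as far as I understand, set up in ss_vae_M2.py roughly like this (paraphrasing from memory, so names may differ slightly): an extra model/guide pair that only scores the observed label, up-weighted by aux_loss_multiplier.

```python
def model_classify(self, xs, ys=None):
    # auxiliary model: only scores the observed label y, up-weighted
    pyro.module("ss_vae", self)
    with pyro.plate("data"):
        if ys is not None:
            alpha = self.encoder_y.forward(xs)
            with pyro.poutine.scale(scale=self.aux_loss_multiplier):
                pyro.sample("y_aux", dist.OneHotCategorical(alpha), obs=ys)

def guide_classify(self, xs, ys=None):
    # dummy guide: the auxiliary term has no latent variables
    pass
```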
The reason I ask is that I think the decoder_x network should already be able to do this, since the num_instances dimension is ordered:

```python
x_loc = torch.reshape(loc, (-1, batch_size, num_instances, num_categories))
```

This comes from the snippet discussed in the topic dist.Bernoulli recognize the right batch_size, but dist.Categorical won't - Tutorials - Pyro Discussion Forum:
```python
loc = self.decoder_x.forward(thetas)
# reshape so that the last two dims are (num_instances, num_categories)
loc = torch.reshape(loc, (-1, batch_size, num_instances, num_categories))
xs_hat = pyro.sample("x", dist.Categorical(logits=loc, validate_args=False).to_event(1), obs=xs)
```
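Just to make concrete what I mean by checking which entries are reproduced right and which are wrong: at evaluation time I could of course compute something like this by hand (a hypothetical helper, assuming xs holds one integer category index per entry):

```python
import torch

def per_entry_accuracy(loc: torch.Tensor, xs: torch.Tensor) -> torch.Tensor:
    # loc: (..., batch_size, num_instances, num_categories) logits from decoder_x
    # xs:  (batch_size, num_instances) integer category index of each entry
    preds = loc.argmax(dim=-1)          # most likely category for each entry
    correct = (preds == xs)             # elementwise: which entries are exactly reproduced
    return correct.float().mean()       # fraction of entries reproduced correctly
```

But what I actually want is for the training objective itself to push each entry (in order) toward the right category, not just the overall distribution of categories.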
And I also don't know how to set up the MSE loss for x in the same way as classify_model() in ss_vae. Maybe I should write the same generative process of x in an mse_model() with a dummy mse_guide(), something like the sketch below?
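This is only a sketch of what I'm imagining (strictly it would be a categorical cross-entropy rather than a literal MSE, since xs is categorical; encoder_theta is a hypothetical stand-in for however thetas are computed in my model):

```python
import torch
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine

class SSVAE(torch.nn.Module):  # stands in for the real ss_vae-style class
    # ... existing __init__, model, guide, model_classify, guide_classify ...

    def mse_model(self, xs, ys=None):
        pyro.module("ss_vae", self)
        batch_size = xs.size(0)
        with pyro.plate("data"):
            # rerun the deterministic path that produces the logits for x
            thetas = self.encoder_theta.forward(xs)   # hypothetical module
            loc = self.decoder_x.forward(thetas)
            loc = torch.reshape(loc, (-1, batch_size, self.num_instances, self.num_categories))
            # extra, up-weighted likelihood term on xs, analogous to model_classify for y
            with poutine.scale(scale=self.aux_loss_multiplier):
                pyro.sample("x_aux", dist.Categorical(logits=loc).to_event(1), obs=xs)

    def mse_guide(self, xs, ys=None):
        # dummy guide: the auxiliary term introduces no latent variables
        pass
```

Then I would add it as an extra SVI loss, the same way ss_vae_M2.py does for the classifier, e.g. losses.append(SVI(ss_vae.mse_model, ss_vae.mse_guide, optimizer, loss=Trace_ELBO())). Is that the right way to do it, or is it unnecessary?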
These things are just not as simple and intuitive as they are in plain PyTorch, and I'm confused about some of the details.