Is it possible to build a reconstruction loss that learns sequential/positional relations among data entries?

Suppose I want the model to reproduce a new image exactly, with its entries (logits) in the same order as in the original input image, rather than minimizing the KL divergence under a “bag-of-words” assumption. Do I need to manually construct an auxiliary MSE reconstruction loss that checks which entries are reproduced correctly and which are wrong?
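A small check suggests that a categorical likelihood over ordered positions is already position-sensitive. The sketch below uses plain torch.distributions, with Independent(..., 1) standing in for Pyro's .to_event(1). The hand-built logits and the sequences xs and xs_swapped are made-up illustrative values, not taken from a real model. The sketch compares two sequences that contain the same entries in a different order, and it also shows one way to count which entries are right and which are wrong.

```python
import torch
from torch.distributions import Categorical, Independent

num_instances, num_categories = 4, 4

# Hypothetical decoder output: one row of logits per ordered position.
# Position i strongly prefers category i.
logits = torch.full((num_instances, num_categories), -5.0)
logits[torch.arange(num_instances), torch.arange(num_instances)] = 5.0

# Independent(..., 1) plays the role of Pyro's .to_event(1): it sums the
# per-position log-probs into one log-likelihood for the whole sequence.
likelihood = Independent(Categorical(logits=logits), 1)

xs = torch.tensor([0, 1, 2, 3])          # matches the decoder position by position
xs_swapped = torch.tensor([0, 1, 3, 2])  # same entries ("bag of words"), two swapped

# A bag-of-words view (category counts) cannot tell the two sequences apart...
same_bag = torch.equal(torch.bincount(xs, minlength=num_categories),
                       torch.bincount(xs_swapped, minlength=num_categories))

# ...but the sequence log-likelihood can, because log p(x_i | logits_i)
# only ever compares position i with the logits for position i.
ll_ordered = likelihood.log_prob(xs)
ll_swapped = likelihood.log_prob(xs_swapped)

# Which entries are reproduced right and which are wrong, per position.
reconstruction = logits.argmax(dim=-1)
correct = reconstruction == xs_swapped

print(same_bag)                 # True
print(ll_ordered > ll_swapped)  # tensor(True)
print(correct.tolist())         # [True, True, False, False]
```

The per-position accuracy here is only a diagnostic; it is not differentiable, so it would be logged rather than optimized.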

Take the Semi-Supervised Variational Autoencoder (SS-VAE) as an example. From the example, I see that it minimizes the “elbo loss” between the model and the guide, plus an auxiliary loss on the label “y”. Do I need to construct a similar reconstruction loss for xs?

I ask because I think the decoder_x network should already be able to do this, since the num_instances dimension is ordered:

x_loc = torch.reshape(loc, (-1, batch_size, num_instances, num_categories))

This line comes from the snippet discussed in the Pyro forum topic “dist.Bernoulli recognize the right batch_size, but dist.Categorical won’t” (Tutorials category):

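loc = self.decoder_x(thetas)  # calling the module directly (rather than .forward) also runs any registered hooks
# (-1, batch_size, num_instances, num_categories): one row of category logits per ordered position
loc = torch.reshape(loc, (-1, batch_size, num_instances, num_categories))
# .to_event(1) treats the num_instances positions as one event: xs[..., i] is scored against loc[..., i, :]
xs_hat = pyro.sample("x", dist.Categorical(logits=loc, validate_args=False).to_event(1), obs=xs)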

I also don’t know how to set up an MSE loss for x the way classify_model() does for y in ss_vae. Should I write the same generative process for x in an mse_model(), paired with a dummy mse_guide()?

These things are not as simple and intuitive as their PyTorch counterparts, and I’m confused about some of the details.