Conditional VAE

Hi,
I'm trying to build a Conditional VAE.
Given a specific digit label, I want to generate MNIST digits.
I have found an implementation in PyTorch, but I want to write it in the Pyro framework.
Blog implementation: https://wiseodd.github.io/techblog/2016/12/17/conditional-vae/

Is it enough to concatenate X and Y (X: the data, Y: the one-hot encoded label)
and map the result to the latent z? That would give a guide with distribution Qφ(z|X,Y).

The model then takes z and Y and defines the distribution Pθ(X|z,Y).

The generation process should be something like:

  • Sample z from the prior N(0,1). *In the general CVAE formulation the prior is conditional, P(z|Y), rather than P(z) as in the vanilla VAE; a fixed N(0,1) prior does not depend on Y.
  • Deterministically set the label Y.
  • Decode (z, Y).
  • Obtain digits of this specific label Y.
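Once a decoder is trained, those four steps amount to a few lines of PyTorch. In this sketch `decoder` is an untrained stand-in with hypothetical layer sizes that takes the concatenation [z, Y]; in practice you would use the trained decoder. The function returns the Bernoulli means, which are what is usually plotted.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

Y_DIM, Z_DIM, H_DIM, X_DIM = 10, 20, 256, 784  # hypothetical sizes

# Untrained stand-in for p_theta(x | z, y); input is [z, y].
decoder = nn.Sequential(
    nn.Linear(Z_DIM + Y_DIM, H_DIM), nn.ReLU(),
    nn.Linear(H_DIM, X_DIM), nn.Sigmoid(),
)


@torch.no_grad()
def generate(decoder, digit, n=16):
    """Generate n images conditioned on a single digit label."""
    z = torch.randn(n, Z_DIM)                                       # sample z ~ N(0, I)
    y = F.one_hot(torch.full((n,), digit, dtype=torch.long), Y_DIM).float()  # fix Y
    probs = decoder(torch.cat([z, y], dim=-1))                      # decode (z, Y)
    return probs.view(n, 28, 28)                                    # pixel means as images


imgs = generate(decoder, digit=7)
```

Use `torch.bernoulli(probs)` instead of the means if you want binary samples; varying z while keeping `digit` fixed should change the style but not the class once the decoder is trained.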

My intuition says that a classifier is missing somewhere, so that we could also add an auxiliary loss to the main loss.

What is your opinion on how I should proceed?

Many thanks in advance.