I'm trying to build a conditional VAE (CVAE).
Given a specific digit label, I want to generate MNIST digits of that class.
I have found an implementation in PyTorch, but I want to build it in the Pyro framework.
Blog implementation: https://wiseodd.github.io/techblog/2016/12/17/conditional-vae/
Is it enough to concatenate X and Y (X: data, Y: one-hot encoded label)
and map them to the latent z? That way we get a guide with distribution Qφ(z|X,Y).
The model then takes z and Y and defines the distribution Pθ(X|z,Y).
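The generation process should be something like:
- sample z from the prior N(0,1). *In general the CVAE prior is conditional, P(z|Y), rather than P(z) as in a vanilla VAE; here it is simply taken to be N(0,1), independent of Y.
- deterministically set the label Y
- decode [z, Y] with the model network
- obtain digits of that specific label Y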
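The steps above need no inference machinery once training is done; they are a plain forward pass through the decoder. A minimal sketch, assuming a decoder that maps the concatenation [z, Y] to per-pixel Bernoulli logits. The `decoder` here is a hypothetical, untrained stand-in for the trained Pθ(X|z,Y) network, and `generate` and the size constants are illustrative names.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

X_DIM, Y_DIM, Z_DIM, H_DIM = 784, 10, 2, 256  # hypothetical sizes

# Hypothetical stand-in for a trained P_theta(X | z, Y) network (untrained here).
decoder = nn.Sequential(
    nn.Linear(Z_DIM + Y_DIM, H_DIM), nn.ReLU(),
    nn.Linear(H_DIM, X_DIM),  # per-pixel Bernoulli logits
)


@torch.no_grad()
def generate(decoder, digit, n=16):
    """Generate n images of class `digit` from a CVAE decoder."""
    # 1. Sample z from the prior N(0, I), which here ignores the label.
    z = torch.randn(n, Z_DIM)
    # 2. Fix the label deterministically and one-hot encode it.
    y = F.one_hot(torch.full((n,), digit, dtype=torch.long), Y_DIM).float()
    # 3. Decode [z, Y] into per-pixel probabilities.
    probs = torch.sigmoid(decoder(torch.cat([z, y], dim=-1)))
    # 4. Reshape to 28x28 images (mean images; torch.bernoulli(probs) gives binary samples).
    return probs.view(n, 28, 28)


images = generate(decoder, digit=7)
```

Varying z while keeping Y fixed should then change the style of the digit (slant, stroke width) but not its class, which is a quick sanity check that the decoder actually uses the label.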
My intuition says that a classifier is missing somewhere, so that we could also add an auxiliary loss to the main loss.
What is your opinion on how I should proceed?
Many thanks in advance.