 # CVAE: Learning Structured Output Representation using Deep Conditional Generative Models

I’m finishing an implementation of the paper *Learning Structured Output Representation using Deep Conditional Generative Models* using Pyro, and I need help with the final step (this will become a nice tutorial about CVAEs):

I have already implemented the MNIST example; it works well, and I was able to reproduce the images and numbers from the paper:

In the paper, the authors train a CVAE to predict the 3 missing quadrants of an MNIST image given 1 quadrant as input (as shown above). Since the data lies in the [0, 1] interval, I use -1 to mark the 3 missing quadrants in the inputs and the 1 missing quadrant in the outputs.
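To make the -1 convention concrete, here is a small sketch in plain PyTorch (assuming 28 x 28 images; `make_quadrant_input` and the choice of keeping the top-left quadrant are mine, not from the paper or tutorial):

```python
import torch

def make_quadrant_input(image):
    # Hypothetical helper: keep the top-left 14x14 quadrant of a
    # 28x28 image and mark the other three quadrants with -1.
    x = image.clone()
    x[:14, 14:] = -1   # top-right quadrant
    x[14:, :] = -1     # bottom half (two quadrants)
    return x

img = torch.rand(28, 28)        # stand-in for an MNIST digit in [0, 1]
x = make_quadrant_input(img)
```

The three masked quadrants cover 3 x 14 x 14 = 588 pixels, which is where the count of 588 further down comes from.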

In my model, I remove the elements of the tensors that have the value -1, as seen in `mask_loc` and `mask_ys`:

```python
def model(self, xs, ys=None):
    # register this pytorch module and all of its sub-modules with pyro
    pyro.module("generation_net", self)
    with pyro.plate("data"):

        # Prior network uses the baseline predictions as initial guess.
        # This is the generative process with recurrent connection
        y_hat = self.baseline_net(xs).view(xs.shape)

        # sample the handwriting style from the prior distribution, which is
        # modulated by the input xs.
        prior_loc, prior_scale = self.prior_net(xs, y_hat)
        zs = pyro.sample('z', dist.Normal(prior_loc, prior_scale).to_event(1))

        # the output y is generated from the distribution pθ(y|x, z)
        loc = self.generation_net(zs)
        # we will only sample in the masked image
        mask_loc = loc[(xs == -1).view(-1, 784)]
        mask_ys = ys[xs == -1] if ys is not None else None
        # score y only on the masked pixels
        pyro.sample('y', dist.Bernoulli(mask_loc).to_event(1), obs=mask_ys)
        # return the loc so we can visualize it later
        return loc
```

Everything works great… except that, in the final step, I need to evaluate the model. When I call:

```python
predictive = Predictive(pre_trained_cvae.model,
                        guide=pre_trained_cvae.guide,
                        num_samples=100)
preds = predictive(inputs, outputs)['y'].mean(dim=0)
```

my `y`'s have 588 elements (3/4 of a 784-pixel digit), as expected, because I removed 1/4 of the elements during sampling in the model. The only workaround I see is to write code that fills in the missing 1/4 after prediction, which involves loops and would make everything slow and ugly.

It would be great if, during sampling within the model, I could skip the mask, so that my `y`'s would have all 784 elements. I’d like to do something similar to what we do with regular NNs: predict the full 28 x 28 image and customize the loss so that it is only computed over the masked regions (entries != -1).
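For reference, the regular-NN version of that masked loss can be sketched like this (plain PyTorch with random stand-in tensors; the shapes follow the flattened 784-pixel images above):

```python
import torch
import torch.nn.functional as F

# Stand-in tensors: `pred` is a full 784-pixel prediction in (0, 1),
# `xs` marks the unknown region with -1, `ys` holds the targets there.
pred = torch.rand(4, 784)
xs = torch.rand(4, 784)
xs[:, 196:] = -1              # pretend the last 3 quadrants are unknown
ys = torch.rand(4, 784)

mask = xs == -1               # boolean mask of the pixels to score
# the loss only sees the masked pixels; the -1 markers never enter it
loss = F.binary_cross_entropy(pred[mask], ys[mask])
```

Note that boolean indexing with `pred[mask]` flattens the selected pixels into a 1-D tensor, which is the same mechanism that produces the 588-element outputs in the question.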

However, I’m afraid that if I just use the full `loc` and `ys` in the model without changing anything, the algorithm will also fit the -1 values during optimization, which would be an error.

Can anybody help?

Thanks!


i’m not sure i understand exactly but can you do something like this?

```python
from functools import partial

def model(flag, ...):
    if flag:
        # do something
    else:
        # do something else

model_true = partial(model, True)
model_false = partial(model, False)

svi = SVI(model_true, ...)
predictive = Predictive(model_false, ...)
```
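Applied to the CVAE above, the idea would be roughly the following (a minimal runnable sketch with plain tensors; the flag name and the constant stand-in for `generation_net` are assumptions of mine, not the tutorial's code):

```python
from functools import partial
import torch

def model(apply_mask, xs, ys=None):
    # `loc` stands in for the full 784-pixel generation_net output
    loc = torch.sigmoid(torch.zeros_like(xs))
    if apply_mask:
        # training path: score only the unknown pixels, as in the
        # original model's masked y sample site
        return loc[xs == -1]
    # prediction path: keep all 784 pixels, so Predictive returns
    # full images and no post-hoc filling is needed
    return loc

train_model = partial(model, True)     # hand this to SVI
predict_model = partial(model, False)  # hand this to Predictive

xs = torch.rand(2, 784)
xs[:, 196:] = -1                       # 3 masked quadrants per image
```

The training variant still hides the -1 entries from the loss, while the prediction variant returns full-size outputs.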

Something like that worked! Thanks!

Just created a PR with the tutorial:

Thanks again @martinjankowiak!
Cheers!
