CVAE: Learning Structured Output Representation using Deep Conditional Generative Models

I’m finishing my implementation of the paper *Learning Structured Output Representation using Deep Conditional Generative Models* using Pyro, and I need help with the final step (this will become a nice tutorial about CVAEs :slight_smile: ):

I already implemented the MNIST example; it works perfectly, and I was able to reproduce the same images and numbers from the paper:

In the paper, the authors train a CVAE to predict the missing 3 quadrants of an MNIST image, given 1 quadrant as input (as shown above). Since the data lies in the [0, 1] interval, I use -1 to mark the 3 missing quadrants in the inputs and the 1 missing quadrant in the outputs.
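For concreteness, the input construction could look like this in PyTorch (a sketch with a hypothetical `mask_quadrants` helper, not the tutorial’s actual code):

```python
import torch

def mask_quadrants(img, fill=-1.0):
    # Keep the top-left 14x14 quadrant of a 28x28 image and fill the
    # other three quadrants with `fill` (-1 marks "missing" pixels).
    out = torch.full_like(img, fill)
    out[..., :14, :14] = img[..., :14, :14]
    return out

img = torch.rand(28, 28)   # stand-in for an MNIST digit in [0, 1]
inp = mask_quadrants(img)  # network input: 1 visible quadrant, rest is -1
```

The target `ys` can be built the same way, with -1 in the quadrant that was kept visible in the input.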

In my model, I remove the elements of the tensors that have value -1, as seen in `mask_loc` and `mask_ys`:

def model(self, xs, ys=None):
    # register this pytorch module and all of its sub-modules with pyro
    pyro.module("generation_net", self)
    with pyro.plate("data"):

        # the prior network uses the baseline predictions as an initial guess;
        # this is the generative process with a recurrent connection
        with torch.no_grad():
            y_hat = self.baseline_net(xs).view(xs.shape)

        # sample the handwriting style from the prior distribution,
        # which is modulated by the input xs
        prior_loc, prior_scale = self.prior_net(xs, y_hat)
        zs = pyro.sample('z', dist.Normal(prior_loc, prior_scale).to_event(1))

        # the output y is generated from the distribution p_theta(y|x, z)
        loc = self.generation_net(zs)
        # we will only sample in the masked (missing) part of the image
        mask_loc = loc[(xs == -1).view(-1, 784)]
        mask_ys = ys[xs == -1] if ys is not None else None
        pyro.sample('y', dist.Bernoulli(mask_loc).to_event(1), obs=mask_ys)
        # return loc so we can visualize it later
        return loc

Everything works great… except that, in the final step, I need to evaluate the model. When I call:

predictive = Predictive(pre_trained_cvae.model, ...)
preds = predictive(inputs, outputs)['y'].mean(dim=0)

my `y`s have 588 elements (3/4 of a digit), as expected, because I removed 1/4 of the elements during sampling in the model. The only workaround I see is to write code that fills in the missing 1/4 after prediction, which involves loops and will make everything slow and ugly.

It would be great if, in the sampling within the model, I could skip the masking so that my `y`s would have 784 elements. I’d like to do something similar to what we do with regular NNs: predict the full 28 × 28 output and customize the loss so that it only computes losses in the masked regions (values != -1).
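In plain PyTorch, that masked-loss idea might look something like this (a sketch with a hypothetical `masked_bce` helper; this is the regular-NN analogy, not the Pyro solution):

```python
import torch
import torch.nn.functional as F

def masked_bce(pred, target, inputs, missing=-1.0):
    # Compute BCE only at the positions that were masked out in the input
    # (i.e. where the input equals the `missing` marker), so the -1 values
    # never enter the optimization.
    mask = inputs == missing
    return F.binary_cross_entropy(pred[mask], target[mask])

pred = torch.tensor([0.9, 0.1, 0.5, 0.5])      # network output for 4 pixels
target = torch.tensor([1.0, 0.0, 1.0, 0.0])    # ground-truth pixels
inputs = torch.tensor([-1.0, -1.0, 0.3, 0.7])  # first two pixels are missing
loss = masked_bce(pred, target, inputs)        # only the first two contribute
```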

However, I’m afraid that if I just use `loc` and `ys` in the model without changing anything, the optimization will also take all the -1 values into account, which would be an error.

Can anybody help?



I’m not sure I understand exactly, but can you do something like this?

from functools import partial

def model(flag, ...):
    if flag:
        # do something
        # do something else

model_true = partial(model, True)
model_false = partial(model, False)

svi = SVI(model_true, ...)
predictive = Predictive(model_false, ...)
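As a toy illustration of this `partial`-based switch (pure Python, not the actual CVAE model): a flag decides whether the output is masked for training or returned in full for prediction:

```python
from functools import partial

def model(apply_mask, xs, ys=None):
    loc = [0.5] * len(xs)  # stand-in for the generation net's per-pixel output
    if apply_mask:
        # training: keep only the positions that were masked in the input
        return [p for p, x in zip(loc, xs) if x == -1]
    # prediction: return the full-length output for Predictive
    return loc

train_model = partial(model, True)     # use with SVI
predict_model = partial(model, False)  # use with Predictive

xs = [-1, -1, 0.3, -1]
masked_out = train_model(xs)   # 3 elements: only the missing pixels
full_out = predict_model(xs)   # 4 elements: the whole output
```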


Something like that worked! Thanks!

Just created a PR with the tutorial:

Thanks again @martinjankowiak!