I’m finishing an implementation of the paper "Learning Structured Output Representation using Deep Conditional Generative Models" in Pyro and need help with the final step (this will become a nice tutorial about CVAEs).
I have already implemented the MNIST example. It works perfectly, and I was able to reproduce the same images and numbers reported in the paper.
In the paper, the authors train a CVAE to predict the 3 missing quadrants of an MNIST image, given 1 quadrant as input. Since the pixel values lie in the [0, 1] interval, I use -1 to mark the 3 missing quadrants in the inputs, and the 1 known (input) quadrant in the outputs.
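For readers following along, the -1 convention can be sketched as below. This is an illustration, not my actual data pipeline: `split_quadrants` is a hypothetical helper, the images are assumed to be a `(batch, 28, 28)` tensor in [0, 1], and keeping the top-left quadrant as input is an arbitrary choice.

```python
import torch


def split_quadrants(images):
    """Build CVAE inputs/targets from MNIST images with values in [0, 1].

    images: (batch, 28, 28) tensor. The top-left quadrant (an illustrative
    choice) is kept as input; the other 3 quadrants become the targets.
    """
    inputs = images.clone()
    outputs = images.clone()
    known = torch.zeros(28, 28, dtype=torch.bool)
    known[:14, :14] = True
    # -1 lies outside [0, 1], so it can safely flag "missing" pixels.
    inputs[:, ~known] = -1   # hide the 3 quadrants to be predicted
    outputs[:, known] = -1   # the given quadrant is not a prediction target
    return inputs, outputs
```

Every pixel ends up marked with -1 in exactly one of the two tensors: 588 per image in `inputs`, 196 in `outputs`.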
In my model, I remove the elements marked with -1 before sampling `y`, so that only the pixels missing from the input are scored:
```python
import torch
import pyro
import pyro.distributions as dist


def model(self, xs, ys=None):
    # register this pytorch module and all of its sub-modules with pyro
    pyro.module("generation_net", self)
    with pyro.plate("data"):

        # Prior network uses the baseline predictions as initial guess.
        # This is the generative process with recurrent connection
        with torch.no_grad():
            y_hat = self.baseline_net(xs).view(xs.shape)

        # sample the handwriting style from the prior distribution, which is
        # modulated by the input xs.
        prior_loc, prior_scale = self.prior_net(xs, y_hat)
        zs = pyro.sample('z', dist.Normal(prior_loc, prior_scale).to_event(1))

        # the output y is generated from the distribution pθ(y|x, z)
        loc = self.generation_net(zs)

        # we will only sample in the masked image
        mask_loc = loc[(xs == -1).view(-1, 784)]
        mask_ys = ys[xs == -1] if ys is not None else None
        pyro.sample('y', dist.Bernoulli(mask_loc).to_event(1), obs=mask_ys)

    # return the loc so we can visualize it later
    return loc
```
Everything works great, except for the final step, where I need to evaluate the model. When I call:
```python
predictive = Predictive(pre_trained_cvae.model,
                        guide=pre_trained_cvae.guide,
                        num_samples=100)
preds = predictive(inputs, outputs)['y'].mean(dim=0)
```
the `y` samples have 588 elements (3/4 of a digit), as expected, because the model drops the known quadrant (1/4 of the 784 pixels) before sampling. The only workaround I see is to write code that fills in the missing 1/4 after prediction, which involves loops and will make everything slow and ugly.
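For what it's worth, the fill-in step may not need loops: boolean-mask assignment in PyTorch writes values into the masked positions in the same row-major order in which `tensor[mask]` reads them out. A minimal sketch, assuming `inputs` is flattened to `(batch, 784)` and `preds` holds exactly one value per -1 entry (either as one flat tensor or as `(batch, 588)`); `fill_missing_quadrants` is a hypothetical helper name.

```python
import torch


def fill_missing_quadrants(inputs, preds):
    """Scatter predictions back into full 28 x 28 images, loop-free.

    inputs: (batch, 784) tensor in which -1 marks the predicted pixels.
    preds:  one value per -1 entry of `inputs`, in the row-major order
            produced by indexing with the boolean mask `inputs == -1`.
    """
    mask = inputs == -1
    full = inputs.clone()
    # Boolean-mask assignment fills the masked slots in order, batch by batch.
    full[mask] = preds.reshape(-1)
    return full.view(-1, 28, 28)
```

The known quadrant is copied unchanged from `inputs`, so the result is the input quadrant plus the predicted ones.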
It would be great not to apply the mask when sampling in the model, so that my `y` samples would have all 784 elements. I’d like to do something similar to what we do with regular NNs: predict the full 28 x 28 image and customize the loss so that it only counts the pixels to be predicted (those != -1 in the targets).
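The "regular NN" version of this idea could look like the following sketch: a binary cross-entropy summed only over the pixels whose target is not -1. `masked_bce` is a hypothetical name, and the shapes `(batch, 784)` are an assumption.

```python
import torch
import torch.nn.functional as F


def masked_bce(probs, targets):
    """Binary cross-entropy over only the pixels where targets != -1.

    probs:   (batch, 784) predicted pixel probabilities in (0, 1).
    targets: (batch, 784) ground truth in [0, 1], with -1 marking ignored pixels.
    """
    mask = targets != -1
    # Replace -1 with 0 so BCE receives valid targets; those terms are dropped.
    safe_targets = torch.where(mask, targets, torch.zeros_like(targets))
    per_pixel = F.binary_cross_entropy(probs, safe_targets, reduction="none")
    # Keep only the masked-in terms, so ignored pixels get zero gradient.
    return per_pixel[mask].sum()
```

Summing (rather than averaging) matches the negative Bernoulli log-likelihood of the predicted pixels.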
However, I’m afraid that if I just observe the full `ys` in the model without changing anything else, the algorithm will include the -1 placeholders in the optimization, which would be wrong.
Can anybody help?