CVAE: Learning Structured Output Representation using Deep Conditional Generative Models

I’m finishing my implementation of the paper *Learning Structured Output Representation using Deep Conditional Generative Models* using Pyro, and I need help with the final step (this will become a nice tutorial about CVAEs :slight_smile: ):

I already implemented the MNIST example; it works perfectly, and I was able to reproduce the same images and numbers from the paper:

In the paper, the authors train a CVAE to predict the missing 3 quadrants of an MNIST image, given 1 quadrant as input (as shown above). Since the data lies in the [0, 1] interval, I use -1 to mark the 3 missing quadrants in the inputs and the 1 missing quadrant in the outputs.
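For concreteness, the input construction could look like this in PyTorch (a sketch with a hypothetical `mask_quadrants` helper, not the tutorial’s actual code):

```python
import torch

def mask_quadrants(img, fill=-1.0):
    # Keep the top-left 14x14 quadrant of a 28x28 image and fill the
    # other three quadrants with `fill` (-1 marks "missing" pixels).
    out = torch.full_like(img, fill)
    out[..., :14, :14] = img[..., :14, :14]
    return out

img = torch.rand(28, 28)   # stand-in for an MNIST digit in [0, 1]
inp = mask_quadrants(img)  # network input: 1 visible quadrant, rest is -1
```

The target `ys` can be built the same way, with -1 in the quadrant that was kept visible in the input.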

In my model, I remove the elements of the tensors that have value -1, as seen in `mask_loc` and `mask_ys`:

def model(self, xs, ys=None):
    # register this pytorch module and all of its sub-modules with pyro
    pyro.module("generation_net", self)
    with pyro.plate("data"):

        # the prior network uses the baseline predictions as an initial guess;
        # this is the generative process with a recurrent connection
        with torch.no_grad():
            y_hat = self.baseline_net(xs).view(xs.shape)

        # sample the handwriting style from the prior distribution,
        # which is modulated by the input xs
        prior_loc, prior_scale = self.prior_net(xs, y_hat)
        zs = pyro.sample('z', dist.Normal(prior_loc, prior_scale).to_event(1))

        # the output y is generated from the distribution p_theta(y|x, z)
        loc = self.generation_net(zs)
        # we will only sample in the masked (missing) part of the image
        mask_loc = loc[(xs == -1).view(-1, 784)]
        mask_ys = ys[xs == -1] if ys is not None else None
        pyro.sample('y', dist.Bernoulli(mask_loc).to_event(1), obs=mask_ys)
        # return loc so we can visualize it later
        return loc

Everything works great… except that, in the final step, I need to evaluate the model. When I call:

predictive = Predictive(pre_trained_cvae.model, ...)
preds = predictive(inputs, outputs)['y'].mean(dim=0)

my `y`s have 588 elements (3/4 of a digit), as expected, because I removed 1/4 of the elements during sampling in the model. The only workaround I see is to write code that fills in the missing 1/4 after prediction, which involves loops and will make everything slow and ugly.

It would be great if, in the sampling within the model, I could skip the masking so that my `y`s would have 784 elements. I’d like to do something similar to what we do with regular NNs: predict the full 28 × 28 output and customize the loss so that it only computes losses in the masked regions (values != -1).
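In plain PyTorch, that masked-loss idea might look something like this (a sketch with a hypothetical `masked_bce` helper; this is the regular-NN analogy, not the Pyro solution):

```python
import torch
import torch.nn.functional as F

def masked_bce(pred, target, inputs, missing=-1.0):
    # Compute BCE only at the positions that were masked out in the input
    # (i.e. where the input equals the `missing` marker), so the -1 values
    # never enter the optimization.
    mask = inputs == missing
    return F.binary_cross_entropy(pred[mask], target[mask])

pred = torch.tensor([0.9, 0.1, 0.5, 0.5])      # network output for 4 pixels
target = torch.tensor([1.0, 0.0, 1.0, 0.0])    # ground-truth pixels
inputs = torch.tensor([-1.0, -1.0, 0.3, 0.7])  # first two pixels are missing
loss = masked_bce(pred, target, inputs)        # only the first two contribute
```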

However, I’m afraid that if I just use `loc` and `ys` in the model without changing anything, the optimization will also take all the -1 values into account, which would be an error.

Can anybody help?



I’m not sure I understand exactly, but can you do something like this?

from functools import partial

def model(flag, ...):
    if flag:
        # do something
        # do something else

model_true = partial(model, True)
model_false = partial(model, False)

svi = SVI(model_true, ...)
predictive = Predictive(model_false, ...)
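As a toy illustration of this `partial`-based switch (pure Python, not the actual CVAE model): a flag decides whether the output is masked for training or returned in full for prediction:

```python
from functools import partial

def model(apply_mask, xs, ys=None):
    loc = [0.5] * len(xs)  # stand-in for the generation net's per-pixel output
    if apply_mask:
        # training: keep only the positions that were masked in the input
        return [p for p, x in zip(loc, xs) if x == -1]
    # prediction: return the full-length output for Predictive
    return loc

train_model = partial(model, True)     # use with SVI
predict_model = partial(model, False)  # use with Predictive

xs = [-1, -1, 0.3, -1]
masked_out = train_model(xs)   # 3 elements: only the missing pixels
full_out = predict_model(xs)   # 4 elements: the whole output
```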


Something like that worked! Thanks!

Just created a PR with the tutorial:

Thanks again @martinjankowiak!