Binarization of dataset for VAE

Hi everyone,

Looking at the tutorial on VAEs (Variational Autoencoders - Pyro Tutorials 1.8.4 documentation), it seems strange to me that the data is not binarized, even though the likelihood of the image is computed under a Bernoulli distribution.

# define the model p(x|z)p(z)
def model(self, x):
    # register PyTorch module `decoder` with Pyro
    pyro.module("decoder", self.decoder)
    with pyro.iarange("data", x.size(0)):
        # setup hyperparameters for prior p(z)
        z_loc = x.new_zeros(torch.Size((x.size(0), self.z_dim)))
        z_scale = x.new_ones(torch.Size((x.size(0), self.z_dim)))
        # sample from prior (value will be sampled by guide when computing the ELBO)
        z = pyro.sample("latent", dist.Normal(z_loc, z_scale).independent(1))
        # decode the latent code z
        loc_img = self.decoder.forward(z)
        # score against actual images
        pyro.sample("obs", dist.Bernoulli(loc_img).independent(1), obs=x.reshape(-1, 784))

How is the likelihood calculated when the observation is not 0 or 1? What is it actually calculating?

I also noticed that digit generation results are better when the data is not binarized, even though the ELBO is better with binarization.

You could take a look at #529 for some context on this. It works out because of the way Bernoulli.log_prob is implemented: it computes binary_cross_entropy, and hence can be passed a continuous value. A more elegant approach would be to use a distribution-valued observation as discussed in #988.
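To make concrete what is being computed for a non-binary pixel, here is a minimal pure-Python sketch of the per-pixel negative binary cross-entropy. The helper name `bernoulli_log_prob` is hypothetical and is not a Pyro or PyTorch function; it only writes out the formula that binary cross-entropy evaluates.

```python
import math

def bernoulli_log_prob(p, x):
    """Negative binary cross-entropy for one pixel (hypothetical helper).

    p: predicted probability in (0, 1)
    x: target in [0, 1], not necessarily 0 or 1
    """
    return x * math.log(p) + (1.0 - x) * math.log(1.0 - p)

# With binary targets this is the usual Bernoulli log-probability.
lp_one = bernoulli_log_prob(0.8, 1.0)   # equals log(0.8)
lp_zero = bernoulli_log_prob(0.8, 0.0)  # equals log(0.2)

# With a grey pixel it interpolates linearly between the two values.
lp_grey = bernoulli_log_prob(0.8, 0.3)
```

For a target strictly between 0 and 1 the result is a weighted average of the two binary log-probabilities, so it is well defined, but it is no longer the log of a normalized density over x.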

I see. Because Bernoulli.log_prob uses BCE under the hood, it behaves like a regular PyTorch implementation.

About the distribution valued observation, how would that work?

See the discussion in the issue Neeraj linked above; it's a WIP idea. Feel free to contribute to the discussion if you have ideas!

You could alternatively explicitly marginalize out the binarization process, which I believe is equivalent to the continuous-observation trick in the tutorial:

with pyro.iarange("pixels", 784):
    # treat each grey-scale intensity as the probability that the pixel is "on"
    binarized = pyro.sample("binarized",
                            dist.Bernoulli(x.reshape(-1, 784)),
                            infer={'enumerate': 'parallel'})
    pyro.sample("obs", dist.Bernoulli(loc_img), obs=binarized)

svi = SVI(model, guide, optim, loss=TraceEnum_ELBO(max_iarange_nesting=2))
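Whether the two approaches coincide exactly may depend on where the sum over binarized values is taken. As a rough check, the pure-Python sketch below compares the two per-pixel quantities, assuming enumeration sums the binarized value out inside the log. Both function names are hypothetical and not part of Pyro.

```python
import math

def expected_log_lik(p, x):
    # Continuous-target trick: E_{b ~ Bernoulli(x)}[log Bernoulli(b; p)]
    return x * math.log(p) + (1.0 - x) * math.log(1.0 - p)

def log_marginal_lik(p, x):
    # Binarization summed out inside the log:
    # log E_{b ~ Bernoulli(x)}[Bernoulli(b; p)]
    return math.log(x * p + (1.0 - x) * (1.0 - p))

p, x = 0.8, 0.3  # decoder probability and grey-scale pixel intensity
gap = log_marginal_lik(p, x) - expected_log_lik(p, x)

# For an already-binary pixel the two quantities agree.
gap_binary = log_marginal_lik(p, 1.0) - expected_log_lik(p, 1.0)
```

Under that assumption, Jensen's inequality makes the continuous-target value a lower bound on the marginalized one, with equality when the pixel is already 0 or 1, so the two objectives may differ on grey pixels.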

Interesting, so internally Pyro traces the binarization, which would be different from having binary input data, I guess.

Why do we need the max_iarange_nesting=2 inside the loss function definition?