I’m going through the Pyro VAE tutorial here (https://github.com/pyro-ppl/pyro/blob/dev/examples/vae/vae.py) but I am struggling to understand how and when to use `to_event()`.
Here’s a bit of code that I am finding confusing:
```python
def model(self, x):
    # register PyTorch module `decoder` with Pyro
    pyro.module("decoder", self.decoder)
    with pyro.plate("data", x.shape[0]):
        # setup hyperparameters for prior p(z)
        z_loc = torch.zeros(x.shape[0], self.z_dim, dtype=x.dtype, device=x.device)
        z_scale = torch.ones(x.shape[0], self.z_dim, dtype=x.dtype, device=x.device)
        # sample from prior (value will be sampled by guide when computing the ELBO)
        z = pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))
        # decode the latent code z
        loc_img = self.decoder.forward(z)
        # score against actual images
        pyro.sample("obs", dist.Bernoulli(loc_img).to_event(1), obs=x.reshape(-1, 784))
        # return the loc so we can visualize it later
        return loc_img
```
My understanding of `to_event()` is that it tells Pyro that certain dimensions are dependent. But that doesn’t make sense to me here:

```python
z = pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))
```

since in a VAE the latents are independent Gaussians. I have read through the documentation on `to_event`, but it is still confusing me. I understand that `to_event` reduces the dimension of the `batch_shape` attribute, but I don’t understand why that’s necessary here. Thanks!