I’m going through the Pyro VAE tutorial here (pyro/vae.py at dev · pyro-ppl/pyro · GitHub), but I am struggling to understand how and when to use `to_event()`.
Here’s a bit of code that I am finding confusing:
```python
import torch
import pyro
import pyro.distributions as dist


# method of the tutorial's VAE class (an nn.Module with self.decoder and self.z_dim)
def model(self, x):
    # register PyTorch module `decoder` with Pyro
    pyro.module("decoder", self.decoder)
    with pyro.plate("data", x.shape[0]):
        # setup hyperparameters for prior p(z)
        z_loc = torch.zeros(x.shape[0], self.z_dim, dtype=x.dtype, device=x.device)
        z_scale = torch.ones(x.shape[0], self.z_dim, dtype=x.dtype, device=x.device)
        # sample from prior (value will be sampled by guide when computing the ELBO)
        z = pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))
        # decode the latent code z (calling the module runs its forward())
        loc_img = self.decoder(z)
        # score against actual images
        pyro.sample("obs", dist.Bernoulli(loc_img).to_event(1), obs=x.reshape(-1, 784))
        # return the loc so we can visualize it later
        return loc_img
```
My understanding of `to_event()` is that it tells Pyro that certain dimensions are dependent. But that doesn’t make sense to me here: `z = pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))`, since in a VAE the latents are independent Gaussians. I have read through the documentation on `to_event`, but this is still confusing me. I understand that `to_event` shrinks the distribution’s `batch_shape` (moving its rightmost dimensions into `event_shape`), but I don’t understand why that’s necessary here. Thanks!