Forward pass in VAE tutorial

Could you please briefly explain the forward pass in a VAE?
Is it something like the following steps?

  1. The guide gets the input X and produces the parameters (z_loc, z_scale) of the latent z.

  2. Sample z from the approximate posterior Q(z|X).

  3. The model samples from the prior z ~ N(0, 1) in order to compute the KL divergence between the approximate posterior Q(z|X) and the prior P(z).

  4. The model decodes the latent z (the guide's output), then we check the reconstruction.

  5. Backprop after that (roughly the training loop sketched below).
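
Something like this training loop, I imagine (a sketch assuming the tutorial's VAE class and a train_loader yielding image batches; the optimizer settings are placeholders):

import pyro
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

vae = VAE()  # the tutorial's module holding `encoder` and `decoder`
svi = SVI(vae.model, vae.guide, Adam({"lr": 1.0e-3}), loss=Trace_ELBO())

for x, _ in train_loader:
    # one forward pass + ELBO estimate + backprop + optimizer step
    loss = svi.step(x)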

Last thing…
The reparameterization trick is taken care of internally in SVI, right?
Thanks!

define the guide (i.e. variational distribution) q(z|x)

def guide(self, x):
    # register PyTorch module `encoder` with Pyro
    pyro.module("encoder", self.encoder)
    with pyro.plate("data", x.shape[0]):
        # use the encoder to get the parameters used to define q(z|x)
        z_loc, z_scale = self.encoder.forward(x)
        # sample the latent code z
        pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))

define the model p(x|z)p(z)

def model(self, x):
    # register PyTorch module `decoder` with Pyro
    pyro.module("decoder", self.decoder)
    with pyro.plate("data", x.shape[0]):
        # setup hyperparameters for prior p(z)
        z_loc = x.new_zeros(torch.Size((x.shape[0], self.z_dim)))
        z_scale = x.new_ones(torch.Size((x.shape[0], self.z_dim)))
        # sample from prior (value will be sampled by guide when computing the ELBO)
        z = pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))
        # decode the latent code z
        loc_img = self.decoder.forward(z)
        # score against actual images
        pyro.sample("obs", dist.Bernoulli(loc_img).to_event(1), obs=x.reshape(-1, 784))

more or less. if you haven't yet, i recommend reading the tutorial and/or the VAE paper. overall, the same thing is being computed as in the other SVI examples, except that in this case the model and guide contain neural networks. so under the hood, pyro is still computing the ELBO by drawing samples from the guide and evaluating the model likelihood, roughly as sketched below.
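
concretely, a single-sample ELBO estimate looks roughly like this under the hood (a sketch using pyro.poutine; `vae` and the batch `x` are assumed to come from the tutorial, and the real Trace_ELBO adds more machinery):

from pyro import poutine

# run the guide and record its sample at the "latent" site
guide_trace = poutine.trace(vae.guide).get_trace(x)
# replay the model against the guide's sample and record log-probabilities
model_trace = poutine.trace(poutine.replay(vae.model, trace=guide_trace)).get_trace(x)
# single-sample ELBO estimate: log p(x, z) - log q(z|x)
elbo = model_trace.log_prob_sum() - guide_trace.log_prob_sum()
loss = -elbo  # minimized by the optimizer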

The reparameterization trick is taken care of internally in SVI, right?

yes, pyro automatically does the reparameterization trick for reparameterizable distributions.
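
for illustration, this is the trick in plain PyTorch (a minimal sketch, not Pyro internals): rsample() draws z = loc + scale * eps with eps ~ N(0, 1), so gradients flow through the sample back to loc and scale.

import torch
import torch.distributions as tdist

loc = torch.tensor(0.5, requires_grad=True)
scale = torch.tensor(1.2, requires_grad=True)

d = tdist.Normal(loc, scale)
print(d.has_rsample)  # True: Normal supports reparameterized sampling

z = d.rsample()  # equivalent to loc + scale * torch.randn(())
z.backward()     # gradients reach loc and scale through the sample
print(loc.grad, scale.grad)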


Thank you! One last thing:
When we score the observation 'obs' against the actual data point (an image here), what loss function does Pyro use to check how well the reconstruction is done?
For example:

        # score against actual images
        pyro.sample("obs", dist.Bernoulli(loc_img).to_event(1), obs=x.reshape(-1, 784))

In case I have another observation that is a one-hot encoded label and I want to score it against the true label, will it automatically use cross-entropy?

The ELBO consists of two parts: the reconstruction term and KL(approximate posterior || prior). What about the reconstruction term?

scoring against the Bernoulli acts as a reconstruction loss, incorporated as the model likelihood. if you have observations from a set of classes, you can consider using a Categorical distribution.
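
to make the connection explicit: the observation's log-probability is exactly the negative of the familiar cross-entropy losses. a quick check (with made-up shapes; note that Categorical expects integer class indices, so a one-hot label would first be converted, e.g. with y.argmax(-1), or you could use OneHotCategorical instead):

import torch
import torch.nn.functional as F
import pyro.distributions as dist

# Bernoulli log-prob == negative binary cross-entropy (per pixel)
probs = torch.rand(4, 784).clamp(1e-4, 1 - 1e-4)
x = torch.bernoulli(torch.full((4, 784), 0.5))
assert torch.allclose(dist.Bernoulli(probs).log_prob(x),
                      -F.binary_cross_entropy(probs, x, reduction="none"),
                      atol=1e-5)

# Categorical log-prob == negative cross-entropy (per example)
logits = torch.randn(4, 10)
y = torch.randint(0, 10, (4,))
assert torch.allclose(dist.Categorical(logits=logits).log_prob(y),
                      -F.cross_entropy(logits, y, reduction="none"),
                      atol=1e-5)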
