Question about the relationship between the latent in the model and the guide in the VAE example

  • What tutorial are you running?
    I am working on the VAE example.
  • What version of Pyro are you using?
    I am using Pyro 1.8.0.
  • Please link or paste relevant code, and steps to reproduce.

I am trying to visualize the prior p(z) in the model. I added

            import matplotlib.pyplot as plt  # needed if not already imported at the top of the file
            plt.figure()
            plt.hist(z.detach().cpu().numpy().flatten(), bins=100)
            plt.show()

right after z = pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1)). The histogram looks like a standard Gaussian distribution, which makes sense to me.

I then updated the guide by multiplying z_loc and z_scale by 100, as follows:

    def guide(self, x):
        # register PyTorch module `encoder` with Pyro
        pyro.module("encoder", self.encoder)
        with pyro.plate("data", x.shape[0]):
            # use the encoder to get the parameters used to define q(z|x)
            z_loc, z_scale = self.encoder.forward(x)
            z_loc = z_loc * 100
            z_scale = z_scale * 100
            # sample the latent code z
            pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))

After this modification, the histogram of z in the model is no longer a standard normal distribution.
I thought p(z) in the model, as the prior, would always be a standard normal distribution. Can anyone explain why the change in the guide alters the prior p(z)?

Thank you!


The guide is a parameterization of an approximation of the posterior distribution, which takes the data into account and is therefore different from the prior. See e.g. this intro tutorial.
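
To make the distinction concrete, in standard VAE notation (with \mu_\phi and \sigma_\phi standing for the encoder outputs z_loc and z_scale):

    p(z) = \mathcal{N}(z \mid 0, I)
    \qquad \text{vs.} \qquad
    q_\phi(z \mid x) = \mathcal{N}\bigl(z \mid \mu_\phi(x), \sigma_\phi(x)^2\bigr)

The prior on the left never depends on x; the guide on the right does, because its parameters come from the encoder.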

Thank you! I understand the difference. My plot is inside the model function:

    def model(self, x):
        # register PyTorch module `decoder` with Pyro
        pyro.module("decoder", self.decoder)
        with pyro.plate("data", x.shape[0]):
            # setup hyperparameters for prior p(z)
            z_loc = torch.zeros(x.shape[0], self.z_dim, dtype=x.dtype, device=x.device)
            z_scale = torch.ones(x.shape[0], self.z_dim, dtype=x.dtype, device=x.device)
            # sample from prior (value will be sampled by guide when computing the ELBO)
            z = pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))
            plt.figure()
            plt.hist(z.detach().cpu().numpy().flatten(), bins=100)
            plt.show()
            print("model z", z.min(), z.max())
            # decode the latent code z
            loc_img = self.decoder.forward(z)
            # score against actual images (with relaxed Bernoulli values)
            pyro.sample(
                "obs",
                dist.Bernoulli(loc_img, validate_args=False).to_event(1),
                obs=x.reshape(-1, 784),
            )
            # return the loc so we can visualize it later
            return loc_img

I assume this code samples z from dist.Normal(z_loc, z_scale). Why does a change in the guide function have an impact here?

I have no idea where model is being called, so I can't judge what it's doing. If it's being called from SVI, then model sample calls are replayed from the guide.
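
As a rough sketch of what "replayed from the guide" means (this is not the actual Trace_ELBO code; vae and x below are placeholders for your VAE instance and a data batch), SVI effectively does the following with Pyro's poutine handlers:

    import pyro.poutine as poutine

    # 1. Run the guide and record the value drawn at every sample site.
    guide_trace = poutine.trace(vae.guide).get_trace(x)

    # 2. Run the model "replayed" against that trace: the "latent" site in
    #    model() reuses the guide's z instead of drawing a fresh prior sample.
    replayed_model = poutine.replay(vae.model, trace=guide_trace)
    loc_img = replayed_model(x)

    # Inside model(), z is therefore guide_trace.nodes["latent"]["value"],
    # which is why rescaling z_loc/z_scale in the guide changes your histogram.

If you call vae.model(x) directly, outside of SVI and without replay, the "latent" site really does draw from dist.Normal(z_loc, z_scale), and the histogram should look standard normal again.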

Thank you! Yes, it is being called from SVI.

then model sample calls are replayed from the guide

Does that mean the sample z in the model is produced by the guide (i.e. q(z|x)) instead of being sampled from dist.Normal(z_loc, z_scale)?

Yes, the ELBO is computed as an expectation w.r.t. the guide; see e.g. here.
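
For reference, the per-datapoint objective being maximized is the standard ELBO (written here in the usual VAE notation, not copied from the tutorial):

    \mathrm{ELBO}(\theta, \phi; x)
      = \mathbb{E}_{q_\phi(z \mid x)}\left[\log p_\theta(x \mid z)\right]
      - \mathrm{KL}\left(q_\phi(z \mid x) \,\|\, p(z)\right)

Both terms involve the guide q_\phi(z \mid x), so the z values that flow through the model during training are the guide's samples, not fresh draws from p(z); p(z) only enters through the KL term.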

Thank you! I forgot that it is an expectation w.r.t. the guide. It is totally clear now!