Estimating log-likelihood in VAE model with importance sampling

Suppose I have a VAE model:

import torch
import torch.nn as nn
import pyro
import pyro.distributions as dist


class VAE(nn.Module):

    def __init__(self, input_shape, latent_dim):
        super().__init__()

        self.latent_dim = latent_dim

        self.encoder = EncoderConv(input_shape, latent_dim)
        self.decoder = DecoderConv(latent_dim, input_shape)

    def model(self, X):
        pyro.module('decoder', self.decoder)
        with pyro.iarange('data', X.shape[0]):
            # Standard normal prior over the latent code
            Z_base_mean = X.new_zeros((X.shape[0], self.latent_dim))
            Z_base_std = X.new_ones((X.shape[0], self.latent_dim))

            base_dist = dist.Normal(Z_base_mean, Z_base_std).independent(1)
            Z = pyro.sample('latent', base_dist)

            # Bernoulli likelihood over the flattened pixels
            X_obs_logits = self.decoder(Z).view(X.shape[0], -1)
            pyro.sample(
                'observation',
                dist.Bernoulli(logits=X_obs_logits).independent(1),
                obs=X.view(X.shape[0], -1),
            )

    def guide(self, X):
        pyro.module('encoder', self.encoder)
        with pyro.iarange('data', X.shape[0]):
            # Amortized variational posterior q(Z | X)
            Z_base_mean, Z_base_std = self.encoder(X)

            base_dist = dist.Normal(Z_base_mean, Z_base_std).independent(1)
            Z = pyro.sample('latent', base_dist)

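(EncoderConv and DecoderConv are ordinary nn.Modules; my real networks are convolutional, but their exact architecture doesn't matter for this question. A minimal fully-connected stand-in with the same interfaces would be:)

class EncoderConv(nn.Module):
    # Maps X to the mean and (positive) std of q(Z | X)
    def __init__(self, input_shape, latent_dim):
        super().__init__()
        self.latent_dim = latent_dim
        n_in = int(torch.tensor(input_shape).prod())
        self.net = nn.Linear(n_in, 2 * latent_dim)

    def forward(self, X):
        h = self.net(X.view(X.shape[0], -1))
        return h[:, :self.latent_dim], h[:, self.latent_dim:].exp()


class DecoderConv(nn.Module):
    # Maps Z to per-pixel Bernoulli logits
    def __init__(self, latent_dim, output_shape):
        super().__init__()
        n_out = int(torch.tensor(output_shape).prod())
        self.net = nn.Linear(latent_dim, n_out)

    def forward(self, Z):
        return self.net(Z)
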
Suppose I have an instance of this model called vae, trained with SVI. How do I estimate the log-likelihood of a given batch of points X?

I have tried the following. First, I add this method to my model:

def reconstruction_model(self, X):
    pyro.module('encoder', self.encoder)
    pyro.module('decoder', self.decoder)
    with pyro.iarange('data', X.shape[0]):
        Z_base_mean, Z_base_std = self.encoder(X)
        
        base_dist = dist.Normal(Z_base_mean, Z_base_std).independent(1)
        Z = pyro.sample('latent', base_dist)
    
        X_rec_logits = self.decoder(Z)
        return pyro.sample(
            'reconstruction', dist.Bernoulli(logits=X_rec_logits).independent(1)
        )

Second, I follow the tutorials:

importance = pyro.infer.Importance(vae.reconstruction_model).run(X)
marginal = pyro.infer.EmpiricalMarginal(importance, sites='reconstruction')

This line outputs a reasonable number:
marginal.log_prob(marginal())

However this line throws an error:
marginal.log_prob(X)

RuntimeError                              Traceback (most recent call last)
<ipython-input-158-0c8cf1717932> in <module>()
----> 1 marginal.log_prob(X)

~/anaconda3/lib/python3.6/site-packages/pyro/distributions/empirical.py in log_prob(self, value)
    127             return self._log_weights.new_zeros(torch.Size()).log()
    128         idxs = torch.arange(self.sample_size)[selection_mask.min(dim=-1)[0]]
--> 129         log_probs = self._categorical.log_prob(idxs)
    130         return log_sum_exp(log_probs)
    131 

~/anaconda3/lib/python3.6/site-packages/torch/distributions/categorical.py in log_prob(self, value)
     98         value = value.expand(value_shape)
     99         log_pmf = self.logits.expand(param_shape)
--> 100         return log_pmf.gather(-1, value.unsqueeze(-1).long()).squeeze(-1)
    101 
    102     def entropy(self):

RuntimeError: cannot unsqueeze empty tensor

My question is: what am I doing wrong? How do I actually compute the log-likelihood of a batch X under my model?

Hi, you can't use EmpiricalMarginal.log_prob that way: as the name suggests, it is an empirical distribution over the weighted samples collected during inference, so it places zero probability mass on any value outside that sample set, which is why it fails on an arbitrary X. We should make that clearer in the tutorial.
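
You can see this on a toy model (everything below is made up purely for illustration):

import torch
import pyro
import pyro.infer
import pyro.distributions as dist

def toy_model():
    return pyro.sample('x', dist.Normal(0., 1.))

posterior = pyro.infer.Importance(toy_model, num_samples=10).run()
marginal = pyro.infer.EmpiricalMarginal(posterior, sites='x')

x = marginal()        # one of the 10 stored samples
marginal.log_prob(x)  # finite: x carries empirical mass
marginal.log_prob(torch.tensor(0.123))  # not a stored sample: zero mass (-inf), or a shape error like the one you hit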

To get an N-sample importance sampling estimate of the marginal log-likelihood manually, use the guide as the proposal distribution: log p(X) ≈ logsumexp_i [log p(X, Z_i) - log q(Z_i | X)] - log N, where Z_i ~ q(Z | X). In code:

import math

import torch
import pyro.poutine as poutine

vae = VAE(...)  # your trained instance

log_weights = []
for i in range(N):
    # Draw Z ~ q(Z | X) and record its log-density under the guide
    guide_trace = poutine.trace(vae.guide).get_trace(X)
    # Replay the model against the same Z to get log p(X, Z)
    model_trace = poutine.trace(
        poutine.replay(vae.model, trace=guide_trace)
    ).get_trace(X)
    # Unnormalized importance weight: log p(X, Z) - log q(Z | X)
    log_weights.append(model_trace.log_prob_sum() - guide_trace.log_prob_sum())

log_z = torch.logsumexp(torch.stack(log_weights), dim=0) - math.log(N)
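
One caveat: log_prob_sum() sums over the whole batch, so log_z above estimates the log-likelihood of the batch as a whole. If you want per-datapoint estimates, here is a sketch along the same lines (continuing from the snippet above; it assumes a Pyro version where traces expose compute_log_prob, and 'latent' / 'observation' are the site names from your model):

per_point_log_ws = []
for i in range(N):
    guide_trace = poutine.trace(vae.guide).get_trace(X)
    model_trace = poutine.trace(
        poutine.replay(vae.model, trace=guide_trace)
    ).get_trace(X)
    # Fill in node['log_prob'] at every sample site
    guide_trace.compute_log_prob()
    model_trace.compute_log_prob()
    # Per-datapoint log p(x_i, z_i) - log q(z_i | x_i); each term has shape (batch,)
    log_w = (model_trace.nodes['latent']['log_prob']
             + model_trace.nodes['observation']['log_prob']
             - guide_trace.nodes['latent']['log_prob'])
    per_point_log_ws.append(log_w)

# One estimate per datapoint, shape (batch,)
log_px = torch.logsumexp(torch.stack(per_point_log_ws), dim=0) - math.log(N)

The batch estimate log_z is then just log_px.sum(). As with any importance sampler, increase N until the estimate stabilizes.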
