Estimating log-likelihood in VAE model with importance sampling

Suppose I have a VAE model:

```python
import torch.nn as nn
import pyro
import pyro.distributions as dist


class VAE(nn.Module):
    def __init__(self, input_shape, latent_dim):
        # super(type(self), self) recurses forever if VAE is ever subclassed
        super().__init__()

        self.latent_dim = latent_dim

        self.encoder = EncoderConv(input_shape, latent_dim)
        self.decoder = DecoderConv(latent_dim, input_shape)

    def model(self, X):
        pyro.module('decoder', self.decoder)
        with pyro.iarange('data', X.shape[0]):
            # Standard normal prior over the latent code
            Z_base_mean = X.new_zeros((X.shape[0], self.latent_dim))
            Z_base_std = X.new_ones((X.shape[0], self.latent_dim))

            base_dist = dist.Normal(Z_base_mean, Z_base_std).independent(1)
            Z = pyro.sample('latent', base_dist)

            X_obs_logits = self.decoder(Z).view(X.shape[0], -1)
            pyro.sample(
                'observation',
                dist.Bernoulli(logits=X_obs_logits).independent(1),
                obs=X.view(X.shape[0], -1)
            )

    def guide(self, X):
        pyro.module('encoder', self.encoder)
        with pyro.iarange('data', X.shape[0]):
            # Approximate posterior q(Z | X) from the encoder
            Z_base_mean, Z_base_std = self.encoder(X)

            base_dist = dist.Normal(Z_base_mean, Z_base_std).independent(1)
            Z = pyro.sample('latent', base_dist)
```

Suppose I have an instance of this model called `vae`, trained with SVI. How can I estimate the log-likelihood of a given batch of points `X`?

I have tried the following. First, I added this method to my model:

```python
def reconstruction_model(self, X):
    pyro.module('encoder', self.encoder)
    pyro.module('decoder', self.decoder)
    with pyro.iarange('data', X.shape[0]):
        Z_base_mean, Z_base_std = self.encoder(X)

        base_dist = dist.Normal(Z_base_mean, Z_base_std).independent(1)
        Z = pyro.sample('latent', base_dist)

        X_rec_logits = self.decoder(Z)
        return pyro.sample(
            'reconstruction', dist.Bernoulli(logits=X_rec_logits).independent(1)
        )
```

```python
importance = pyro.infer.Importance(vae.reconstruction_model).run(X)
marginal = pyro.infer.EmpiricalMarginal(importance, sites='reconstruction')
```

This line outputs a reasonable number:
`marginal.log_prob(marginal())`

However, this line throws an error:
`marginal.log_prob(X)`

```
RuntimeError                              Traceback (most recent call last)
<ipython-input-158-0c8cf1717932> in <module>()
----> 1 marginal.log_prob(X)

~/anaconda3/lib/python3.6/site-packages/pyro/distributions/empirical.py in log_prob(self, value)
127             return self._log_weights.new_zeros(torch.Size()).log()
--> 129         log_probs = self._categorical.log_prob(idxs)
130         return log_sum_exp(log_probs)
131

~/anaconda3/lib/python3.6/site-packages/torch/distributions/categorical.py in log_prob(self, value)
98         value = value.expand(value_shape)
99         log_pmf = self.logits.expand(param_shape)
--> 100         return log_pmf.gather(-1, value.unsqueeze(-1).long()).squeeze(-1)
101
102     def entropy(self):

RuntimeError: cannot unsqueeze empty tensor
```

My question is: what am I doing wrong? How do I actually compute the log-likelihood of a batch `X` in my model?

Hi, you can’t use `EmpiricalMarginal.log_prob` that way - as the name suggests, it’s an empirical distribution of weighted samples and will place zero probability mass on any value outside those samples. We should make that more clear in the tutorial.
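To illustrate the point, here is a toy stand-in for a weighted empirical distribution in plain Python. It is only a sketch of the idea, not Pyro's actual implementation, and the class name `ToyEmpirical` is made up for this example. Any value that was never drawn as a sample gets a log-probability of `-inf`:

```python
import math


class ToyEmpirical:
    """Weighted point masses on a finite set of samples (illustrative only)."""

    def __init__(self, samples, log_weights):
        self.samples = list(samples)
        # Normalize the log-weights so the probabilities sum to 1.
        m = max(log_weights)
        log_norm = m + math.log(sum(math.exp(w - m) for w in log_weights))
        self.log_probs = [w - log_norm for w in log_weights]

    def log_prob(self, value):
        # Add up the mass of every stored sample that equals `value`.
        mass = sum(
            math.exp(lp)
            for s, lp in zip(self.samples, self.log_probs)
            if s == value
        )
        return math.log(mass) if mass > 0 else -math.inf


# Three equally weighted samples, two of which are identical.
emp = ToyEmpirical(samples=[(0, 1), (1, 1), (0, 1)], log_weights=[0.0, 0.0, 0.0])
print(emp.log_prob((0, 1)))  # log(2/3): two of the three samples match
print(emp.log_prob((1, 0)))  # -inf: never sampled, so zero mass
```

A real image batch `X` will essentially never coincide with one of the sampled reconstructions, so it falls into the second case.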

To get an `N`-sample importance sampling estimate of the marginal likelihood manually:

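```python
import math

import torch
import pyro.poutine as poutine

vae = VAE(...)

N = 100  # number of importance samples (example value)

log_weights = []
for i in range(N):
    # Sample latents from the guide, then score those same latents under the model.
    guide_trace = poutine.trace(vae.guide).get_trace(X)
    model_trace = poutine.trace(poutine.replay(vae.model, trace=guide_trace)).get_trace(X)
    # log w_i = log p(X, Z_i) - log q(Z_i | X), summed over the whole batch
    log_weights.append(model_trace.log_prob_sum() - guide_trace.log_prob_sum())

# log Z is approximately logsumexp_i(log w_i) - log N
log_z = torch.logsumexp(torch.stack(log_weights), dim=0) - math.log(N)
```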
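As a sanity check of the estimator itself, independent of Pyro, here is a minimal sketch on a toy model where the marginal is known in closed form. The model (z ~ N(0, 1), x | z ~ N(z, 1), so x ~ N(0, 2)) and the helper names `normal_log_pdf` and `log_marginal_is` are assumptions chosen for illustration and are not part of the VAE above.

```python
import math
import random


def normal_log_pdf(x, mean, var):
    """Log density of a univariate normal with the given mean and variance."""
    return -0.5 * (math.log(2 * math.pi * var) + (x - mean) ** 2 / var)


def log_marginal_is(x, n_samples, q_mean, q_var, rng):
    """Importance sampling estimate of log p(x) for z ~ N(0, 1), x | z ~ N(z, 1)."""
    log_weights = []
    for _ in range(n_samples):
        z = rng.gauss(q_mean, math.sqrt(q_var))  # z ~ q(z | x)
        log_joint = normal_log_pdf(z, 0.0, 1.0) + normal_log_pdf(x, z, 1.0)
        log_weights.append(log_joint - normal_log_pdf(z, q_mean, q_var))
    # logsumexp(log_weights) - log(N), computed stably
    m = max(log_weights)
    return m + math.log(sum(math.exp(w - m) for w in log_weights)) - math.log(n_samples)


x = 1.3
exact = normal_log_pdf(x, 0.0, 2.0)  # true marginal: x ~ N(0, 2)
rng = random.Random(0)

# With the exact posterior N(x / 2, 1 / 2) as proposal, every weight equals p(x).
est_exact_q = log_marginal_is(x, 10, q_mean=x / 2, q_var=0.5, rng=rng)

# With a mismatched proposal (the prior), the estimate is noisy but approaches
# log p(x) as the number of samples grows.
est_rough_q = log_marginal_is(x, 20000, q_mean=0.0, q_var=1.0, rng=rng)

print(exact, est_exact_q, est_rough_q)
```

The same logic applies to the VAE: the closer the guide is to the true posterior, the fewer samples are needed for a tight estimate.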

How can I compute the log-likelihood when I run Bayesian linear regression with MCMC (the Pyro example)?