Suppose I have a VAE model:
class VAE(nn.Module):
    """Variational autoencoder expressed as a Pyro model/guide pair.

    The encoder maps a batch ``X`` to the mean and standard deviation of a
    diagonal Normal over the latent code; the decoder maps a latent code to
    Bernoulli logits over the (flattened) observation.

    Args:
        input_shape: Shape descriptor forwarded to ``EncoderConv`` and
            ``DecoderConv``.
        latent_dim: Size of the latent code.
    """

    def __init__(self, input_shape, latent_dim):
        # Zero-argument super(): ``super(type(self), self)`` recurses forever
        # as soon as VAE is subclassed, because ``type(self)`` is then the
        # subclass and the lookup lands back on this very ``__init__``.
        super().__init__()
        self.latent_dim = latent_dim
        self.encoder = EncoderConv(input_shape, latent_dim)
        self.decoder = DecoderConv(latent_dim, input_shape)

    def model(self, X):
        """Generative model: standard Normal prior on Z, Bernoulli likelihood on X.

        Args:
            X: Batch of observations; the first dimension is the batch size.
        """
        pyro.module('decoder', self.decoder)
        batch_size = X.shape[0]
        with pyro.iarange('data', batch_size):
            # Prior parameters are created from X so they share its dtype/device.
            Z_base_mean = X.new_zeros((batch_size, self.latent_dim))
            Z_base_std = X.new_ones((batch_size, self.latent_dim))
            base_dist = dist.Normal(Z_base_mean, Z_base_std).independent(1)
            Z = pyro.sample('latent', base_dist)
            # Flatten per data point so that one event is one whole observation.
            X_obs_logits = self.decoder(Z).view(batch_size, -1)
            pyro.sample(
                'observation',
                dist.Bernoulli(logits=X_obs_logits).independent(1),
                obs=X.view(batch_size, -1),
            )

    def guide(self, X):
        """Variational posterior: Normal over Z with parameters from the encoder.

        Args:
            X: Batch of observations; the first dimension is the batch size.
        """
        pyro.module('encoder', self.encoder)
        with pyro.iarange('data', X.shape[0]):
            Z_base_mean, Z_base_std = self.encoder(X)
            base_dist = dist.Normal(Z_base_mean, Z_base_std).independent(1)
            # The sampled value is recorded by Pyro at site 'latent'; it is
            # not needed locally, so it is not bound to a name.
            pyro.sample('latent', base_dist)
Suppose I have an instance of this model called vae,
trained with SVI. How can I estimate the log-likelihood of a given batch of points X?
I have tried the following. First, I add the following method to my model:
def reconstruction_model(self, X):
    """Encode ``X``, sample a latent code, and sample a reconstruction.

    Args:
        X: Batch of observations; the first dimension is the batch size.

    Returns:
        The value sampled at the 'reconstruction' site, with the same shape
        as the decoder output.
    """
    pyro.module('encoder', self.encoder)
    pyro.module('decoder', self.decoder)
    with pyro.iarange('data', X.shape[0]):
        Z_base_mean, Z_base_std = self.encoder(X)
        base_dist = dist.Normal(Z_base_mean, Z_base_std).independent(1)
        Z = pyro.sample('latent', base_dist)
        X_rec_logits = self.decoder(Z)
        # Treat every non-batch dimension as part of a single event, matching
        # `model`, which flattens the logits before `.independent(1)`. With a
        # fixed `.independent(1)`, image-shaped logits would leave trailing
        # dims in the batch shape inside the size-B 'data' iarange.
        # NOTE(review): the decoder's output shape is not visible here; for
        # 2-D logits this is identical to the previous `.independent(1)`.
        event_dims = max(X_rec_logits.dim() - 1, 1)
        return pyro.sample(
            'reconstruction',
            dist.Bernoulli(logits=X_rec_logits).independent(event_dims),
        )
Second, I follow the tutorials:
# Build the importance sampler over the reconstruction program (no guide and
# no explicit num_samples are passed, so library defaults apply), then run it
# on the batch X.
sampler = pyro.infer.Importance(vae.reconstruction_model)
importance = sampler.run(X)
# Marginal restricted to the values recorded at the 'reconstruction' site.
marginal = pyro.infer.EmpiricalMarginal(importance, sites='reconstruction')
This line outputs a reasonable number:
# Draw one value from the empirical marginal and score it under that same marginal.
marginal.log_prob(marginal())
However, this line throws an error:
# Fails with the RuntimeError in the traceback that follows: per the frames
# shown, Empirical.log_prob indexes its stored samples with a selection mask
# and the resulting index tensor is empty here (presumably because no stored
# sample equals X).
marginal.log_prob(X)
RuntimeError Traceback (most recent call last)
<ipython-input-158-0c8cf1717932> in <module>()
----> 1 marginal.log_prob(X)
~/anaconda3/lib/python3.6/site-packages/pyro/distributions/empirical.py in log_prob(self, value)
127 return self._log_weights.new_zeros(torch.Size()).log()
128 idxs = torch.arange(self.sample_size)[selection_mask.min(dim=-1)[0]]
--> 129 log_probs = self._categorical.log_prob(idxs)
130 return log_sum_exp(log_probs)
131
~/anaconda3/lib/python3.6/site-packages/torch/distributions/categorical.py in log_prob(self, value)
98 value = value.expand(value_shape)
99 log_pmf = self.logits.expand(param_shape)
--> 100 return log_pmf.gather(-1, value.unsqueeze(-1).long()).squeeze(-1)
101
102 def entropy(self):
RuntimeError: cannot unsqueeze empty tensor
My question is: what am I doing wrong? How do I actually compute the log-likelihood of a batch X under my model?