Hi!
I am trying to extract the marginal distributions of sample sites in my model from an SVI object using svi.marginal(). However, whenever I try to do so, I get the following error:
Traceback (most recent call last):
print(svi.marginal(sites=[f"z_{conv_idx}_1", "mu_0", "mu)"]))
File "/u/nlp/anaconda/main/anaconda3/envs/py37-an/lib/python3.7/site-packages/pyro/infer/abstract_infer.py", line 187, in marginal
return Marginals(self, sites)
File "/u/nlp/anaconda/main/anaconda3/envs/py37-an/lib/python3.7/site-packages/pyro/infer/abstract_infer.py", line 124, in __init__
self._populate_traces(trace_posterior, validate_args)
File "/u/nlp/anaconda/main/anaconda3/envs/py37-an/lib/python3.7/site-packages/pyro/infer/abstract_infer.py", line 128, in _populate_traces
for site in self.sites}
File "/u/nlp/anaconda/main/anaconda3/envs/py37-an/lib/python3.7/site-packages/pyro/infer/abstract_infer.py", line 128, in <dictcomp>
for site in self.sites}
File "/u/nlp/anaconda/main/anaconda3/envs/py37-an/lib/python3.7/site-packages/pyro/infer/abstract_infer.py", line 42, in __init__
samples, weights = self._get_samples_and_weights()
File "/u/nlp/anaconda/main/anaconda3/envs/py37-an/lib/python3.7/site-packages/pyro/infer/abstract_infer.py", line 62, in _get_samples_and_weights
return torch.stack(samples_by_chain, dim=0), torch.stack(weights_by_chain, dim=0)
RuntimeError: stack expects a non-empty TensorList
Here is my Model:
class ConvModel:
def __init__(
self,
num_acts,
batch_size,
vocab_len,
data,
params=None
):
self.num_acts = num_acts
self.batch_size = batch_size
self.vocab_len = vocab_len
self.num_sources = 3
self.data = data
self.trace = TraceEnum_ELBO(max_plate_nesting=1)
self.guide = AutoDelta(poutine.block(self.model, expose=['mu_0', 'mu', 'gamma', 'psi', 'theta', 'pi']))
def model(self, minibatch_idx):
K, T, W = self.data.shape # K is num conversations, T is num statements, W is num words in statement
mu_0 = pyro.sample("mu_0", dist.Dirichlet(torch.ones(self.num_acts)))
mu = pyro.sample("mu", dist.Dirichlet(torch.ones(self.num_acts, self.num_acts)).to_event(1))
gamma = pyro.sample("gamma", dist.Dirichlet(torch.ones(self.num_acts, self.vocab_len)).to_event(1))
with pyro.plate("conversation_params", K):
theta = pyro.sample("theta", dist.Dirichlet(torch.ones(K,self.vocab_len)))
pi = pyro.sample("pi", dist.Dirichlet(torch.ones(K,self.num_sources)))
psi = pyro.sample("psi", dist.Dirichlet(torch.ones(self.vocab_len)))
with pyro.plate("k", K, subsample=minibatch_idx, dim=-1) as k:
for t in pyro.markov(range(T)):
param = mu[z_last] if t > 0 else mu_0.expand(len(minibatch_idx), self.num_acts)
z = pyro.sample(f"z_{k}_{t}", dist.Categorical(param), infer={'enumerate':'parallel'})
gammas = gamma[z]
thetas = theta[k]
combined_shape = (gammas + thetas).shape # (enum_dims, batch_dim, vocab_len)
gammas = gammas.expand(combined_shape)
thetas = thetas.expand_as(gammas)
psis = psi.expand_as(gammas)
stacked = torch.stack([gammas, thetas, psis], dim=-2) # (enum_dims, batch_dim, 3, vocab_len)
combined_prob = pi[k].unsqueeze(-1).expand_as(stacked) * stacked
combined_prob = torch.sum(combined_prob, dim=-2) # (enum_dims, batch_dim, vocab_len)
pyro.sample(f"x_{k}_{t}", dist.Normal(combined_prob, .0001).to_event(1),obs=self.data[k, t])
And here is the code I am using to extract the marginals:
def optimize_dataloader(
model, guide, trace, data_loader, learning_rate=1.0e-3,
batch_size=5, num_epochs=1000, verbosity=50,
):
optimizer = pyro.optim.Adam({"lr": learning_rate},)
svi = pyro.infer.SVI(model, guide, optimizer,
loss=trace)
losses = []
for epoch in range(num_epochs):
for idx, batch in enumerate(data_loader):
conv_idx = batch[0]
batch_data = batch[1]
loss = svi.step(conv_idx)
num_observations = batch_data.shape[0] * batch_data.shape[1] * batch_data.shape[2]
if epoch != 0:
print(svi.marginal(sites=[f"z_{conv_idx}_1", "mu_0", "mu"]))
I have also tried using Importance sampling together with the Marginals function and faced the same error as posted above:
marginal_approx_dist = pyro.infer.Importance(model(conv_idx), guide, num_samples=100)
marginals = pyro.infer.abstract_infer.Marginals(marginal_approx_dist)
Any advice would be greatly appreciated it!