Problem Extracting Marginals from SVI object


I am trying to extract the marginal distributions of sample sites in my model from a Pyro `SVI` object using `svi.marginal()`. However, whenever I do so, I get the following error:

Traceback (most recent call last):
    print(svi.marginal(sites=[f"z_{conv_idx}_1", "mu_0", "mu)"]))
  File "/u/nlp/anaconda/main/anaconda3/envs/py37-an/lib/python3.7/site-packages/pyro/infer/abstract_infer.py", line 187, in marginal
    return Marginals(self, sites)
  File "/u/nlp/anaconda/main/anaconda3/envs/py37-an/lib/python3.7/site-packages/pyro/infer/abstract_infer.py", line 124, in __init__
    self._populate_traces(trace_posterior, validate_args)
  File "/u/nlp/anaconda/main/anaconda3/envs/py37-an/lib/python3.7/site-packages/pyro/infer/abstract_infer.py", line 128, in _populate_traces
    for site in self.sites}
  File "/u/nlp/anaconda/main/anaconda3/envs/py37-an/lib/python3.7/site-packages/pyro/infer/abstract_infer.py", line 128, in <dictcomp>
    for site in self.sites}
  File "/u/nlp/anaconda/main/anaconda3/envs/py37-an/lib/python3.7/site-packages/pyro/infer/abstract_infer.py", line 42, in __init__
    samples, weights = self._get_samples_and_weights()
  File "/u/nlp/anaconda/main/anaconda3/envs/py37-an/lib/python3.7/site-packages/pyro/infer/abstract_infer.py", line 62, in _get_samples_and_weights
    return torch.stack(samples_by_chain, dim=0), torch.stack(weights_by_chain, dim=0)
RuntimeError: stack expects a non-empty TensorList

Here is my model:

class ConvModel:
    """Conversation model with a discrete latent act per statement.

    Global latents: ``mu_0`` (act distribution used at t == 0), ``mu``
    (act-to-act transition rows, indexed by the previous act), ``gamma``
    (per-act word distributions, indexed by the current act) and ``psi``
    (a single shared word distribution). Per-conversation latents: ``theta``
    (word distribution) and ``pi`` (weights mixing the three word
    distributions). Each statement's act ``z`` is marked for parallel
    enumeration.

    NOTE(review): this class was pasted in a truncated/garbled form and does
    not parse as shown (see the notes on individual lines). Also, ``__init__``
    is indented 5 spaces while ``model`` is indented 4; a class body needs
    consistent indentation.
    """
     def __init__(
      # NOTE(review): the parameter list and closing "):" were lost in the
      # paste; judging by the attributes below it presumably takes num_acts,
      # batch_size, vocab_len and data -- TODO confirm.
      self.num_acts = num_acts
      self.batch_size = batch_size
      self.vocab_len = vocab_len
      # NOTE(review): "a = 3 = data" is a SyntaxError (cannot assign to a
      # literal); presumably two statements were merged here:
      # ``self.num_sources = 3`` and ``self.data = data``.
      self.num_sources = 3 = data
      # NOTE(review): also a SyntaxError (cannot assign to a function call);
      # presumably ``self.trace = TraceEnum_ELBO(max_plate_nesting=1)`` and
      # ``self.guide = AutoDelta(poutine.block(...))`` were merged.
      # The AutoDelta guide only sees the exposed global sites listed below;
      # the per-statement z sites are not in the guide (they are marked for
      # enumeration in ``model``), so their marginals cannot be read off the
      # guide directly.
      self.trace = TraceEnum_ELBO(max_plate_nesting=1) = AutoDelta(poutine.block(self.model, expose=['mu_0', 'mu', 'gamma', 'psi', 'theta', 'pi']))

    def model(self, minibatch_idx):
        """Generative model for one minibatch of conversations.

        Args:
            minibatch_idx: indices of the conversations in this minibatch.
                Used as the ``subsample`` of the "k" plate and sized with
                ``len()`` to expand ``mu_0``.
        """
        # NOTE(review): the right-hand side was lost in the paste; presumably
        # K, T, W come from the shape of the stored data -- TODO confirm.
        K, T, W = # K is num conversations, T is num statements, W is num words in statement

        # Global latents shared by all conversations.
        mu_0 = pyro.sample("mu_0", dist.Dirichlet(torch.ones(self.num_acts)))
        mu = pyro.sample("mu", dist.Dirichlet(torch.ones(self.num_acts, self.num_acts)).to_event(1))
        gamma = pyro.sample("gamma", dist.Dirichlet(torch.ones(self.num_acts, self.vocab_len)).to_event(1))

        # Per-conversation latents for all K conversations (not subsampled);
        # the minibatch rows are selected later with theta[k] / pi[k].
        with pyro.plate("conversation_params", K):
            theta = pyro.sample("theta", dist.Dirichlet(torch.ones(K,self.vocab_len)))
            pi = pyro.sample("pi", dist.Dirichlet(torch.ones(K,self.num_sources)))

        psi = pyro.sample("psi", dist.Dirichlet(torch.ones(self.vocab_len)))

        # ``k`` is bound to what pyro.plate yields for the subsampled plate,
        # i.e. the subsample indices (minibatch_idx).
        with pyro.plate("k", K, subsample=minibatch_idx, dim=-1) as k:
            # pyro.markov declares that step t depends only on step t - 1.
            for t in pyro.markov(range(T)):
                # NOTE(review): ``z_last`` is never assigned in this snippet,
                # so any t > 0 raises NameError; presumably ``z_last = z``
                # belongs at the end of the loop body.
                param = mu[z_last] if t > 0 else mu_0.expand(len(minibatch_idx), self.num_acts)
                # NOTE(review): the site name embeds the repr of the tensor
                # ``k`` (e.g. "z_tensor([...])_0"), so names change with every
                # minibatch and only match a lookup that formats an identical
                # tensor (as the svi.marginal call does with conv_idx). A name
                # such as f"z_{t}" would be stable.
                z = pyro.sample(f"z_{k}_{t}", dist.Categorical(param), infer={'enumerate':'parallel'})

                # Word distribution of the current act, and of each
                # conversation in the minibatch.
                gammas = gamma[z]
                thetas = theta[k]
                combined_shape = (gammas + thetas).shape # (enum_dims, batch_dim, vocab_len)
                gammas = gammas.expand(combined_shape)
                thetas = thetas.expand_as(gammas)
                psis = psi.expand_as(gammas)

                # Mix the three word distributions (per-act, per-conversation,
                # shared) using the per-conversation weights pi[k].
                stacked = torch.stack([gammas, thetas, psis], dim=-2) # (enum_dims, batch_dim, 3, vocab_len)
                combined_prob = pi[k].unsqueeze(-1).expand_as(stacked) * stacked
                combined_prob = torch.sum(combined_prob, dim=-2) # (enum_dims, batch_dim, vocab_len)
                # NOTE(review): no ``obs=`` is passed, so this site is sampled
                # rather than conditioned on data. The positional ``[k, t]`` is
                # handed to pyro.sample as an extra positional argument, not
                # as an observation. Presumably ``obs=<data>[k, t]`` was
                # intended -- TODO confirm. The name has the same tensor-repr
                # issue as the z site.
                pyro.sample(f"x_{k}_{t}", dist.Normal(combined_prob, .0001).to_event(1),[k, t])

And here is the code I am using to train the model and extract the marginals:

# Train ``model``/``guide`` with SVI over ``data_loader`` for ``num_epochs``
# epochs and, from the second epoch on, print marginals after every step.
# NOTE(review): this function was pasted in a truncated form and does not
# parse as shown (see the notes below).
def optimize_dataloader(
    # NOTE(review): the closing "):" of this signature was lost in the paste.
    model, guide, trace, data_loader, learning_rate=1.0e-3,
    batch_size=5, num_epochs=1000, verbosity=50,

    # Pyro-wrapped Adam with the requested learning rate.
    optimizer = pyro.optim.Adam({"lr": learning_rate},)
    # NOTE(review): this call lost its remaining arguments and ")";
    # presumably ``loss=trace)``, since ``trace`` is otherwise unused -- TODO
    # confirm.
    svi = pyro.infer.SVI(model, guide, optimizer,
    # NOTE(review): never appended to or returned, so ``loss`` below is
    # discarded.
    losses = []

    for epoch in range(num_epochs):
        for idx, batch in enumerate(data_loader):
            # Presumably each batch is (conversation indices, data tensor),
            # judging by the indexing below -- confirm against the DataLoader.
            conv_idx = batch[0]
            batch_data = batch[1]
            # One gradient step; conv_idx is the only argument passed on to
            # the model and guide.
            loss = svi.step(conv_idx)
            # NOTE(review): computed but unused. batch_data is used only here,
            # so it is never handed to the model by this function.
            num_observations = batch_data.shape[0] * batch_data.shape[1] * batch_data.shape[2]

            if epoch != 0:
               # NOTE(review): this is the line that raises "stack expects a
               # non-empty TensorList". SVI.marginal() builds Marginals from
               # traces collected by SVI.run(); svi.step() collects none, so
               # the sample list passed to torch.stack is empty. SVI.run /
               # marginal are deprecated in recent Pyro. For the enumerated z
               # sites, use TraceEnum_ELBO.compute_marginals(model, guide,
               # conv_idx). For the AutoDelta globals (mu_0, mu, ...), read the
               # guide's point estimates instead.
               print(svi.marginal(sites=[f"z_{conv_idx}_1", "mu_0", "mu"]))

I have also tried using importance sampling (`pyro.infer.Importance`) together with the `Marginals` class, and got the same error as above:

# Approximate the posterior by importance sampling, then build marginals from
# the weighted traces.
#
# Fixes relative to the original snippet:
# * ``Importance`` takes the model *callable*. The original passed
#   ``model(conv_idx)``, i.e. the value returned by a single model run.
# * ``Marginals`` reads the traces gathered by ``TracePosterior.run``. Without
#   calling ``run`` the trace list is empty and ``torch.stack`` raises
#   "stack expects a non-empty TensorList". The model's arguments go to run().
# * ``Marginals`` defaults to the "_RETURN" site, but the model presumably
#   returns nothing, so the latent sites are requested explicitly. The z sites
#   can be added once they have stable (non-tensor) names.
marginal_approx_dist = pyro.infer.Importance(model, guide, num_samples=100)
marginal_approx_dist.run(conv_idx)
marginals = pyro.infer.abstract_infer.Marginals(
    marginal_approx_dist, sites=["mu_0", "mu"]
)

Any advice would be greatly appreciated!