Is ever-decreasing ELBO a good sign?

I have currently implemented a model where the guide is empty (I’m guessing this is MAP estimation). My loss starts around O(10^4) and keeps falling below O(-10^4) over time.

Is this expected? I understand we are now in the likelihood domain and not metrics like MSE, but this just doesn’t seem to plateau. Even the test loss follows a similar pattern so I’m assuming the parameters are learning something about the distribution.
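A general point behind the "likelihood domain" remark: a continuous probability density is not bounded by 1, so its log can be positive and a negative log-likelihood loss can drop below zero. The sketch below only illustrates that fact with a plain Gaussian; `normal_log_density` is a hypothetical helper, and nothing here establishes that this is what drives the curve in this thread.

```python
import math


def normal_log_density(x, mu, sigma):
    """Log of the Normal(mu, sigma) probability density evaluated at x."""
    return (
        -0.5 * math.log(2 * math.pi)
        - math.log(sigma)
        - 0.5 * ((x - mu) / sigma) ** 2
    )


# Evaluate the density at its own mean for progressively smaller scales.
# The negative log-density keeps decreasing as sigma shrinks.
for sigma in (1.0, 0.1, 0.001):
    nll = -normal_log_density(0.0, 0.0, sigma)
    print(f"sigma={sigma:g}  negative log-density={nll:.3f}")
```

A loss that keeps falling can therefore mean the model is fitting better, or that some learned scale is collapsing toward zero; the sign of the loss alone does not distinguish the two.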

An empty guide implies that there are no latent sites in your model, in which case you are doing maximum likelihood. As for the behavior you are seeing, I suspect something is wrong with the distribution arguments or support during training. Turning on Pyro validation may help you identify the problem.

How exactly should I enable Pyro validation?

EDIT: Found it - pyro.enable_validation(True).

For reference, here is what my model looks like:

  def model(self, L, X, pheno):
    """
    :param L: Length of each Time Series in batch, B
    :param X: Time Series, T_max x B x D
    :param pheno: Binary Multi-label class targets, B x 25
    """
    pyro.module('dmm', self, update_module_params=True)

    B = pheno.size(0)
    T_max = X.size(0)

    df = pyro.param('df', torch.tensor(2.0, device=X.device),
                    constraint=constraints.positive)
    mu_0 = pyro.param('mu_0', torch.zeros(1, self.latent_size,
                                          device=X.device))
    sigma_0 = pyro.param('sigma_0', torch.zeros(1, self.latent_size,
                                                device=X.device))
    mu_0 = mu_0.expand(B, -1)
    sigma_0 = sigma_0.expand(B, -1)

    z_pheno = torch.zeros(B, self.latent_size).to(X.device)

    h_prev = self.h_0.expand(B, -1).to(X.device)
    with pyro.plate('X_plate', B, device=X.device):
      z_prev = pyro.sample('Z_0', dist.Normal(mu_0, sigma_0.exp()).to_event(1))
      for t in range(T_max):
        with poutine.mask(mask=(t < L)):
          h_t, z_mu, z_log_var = self.transition(z_prev, h_prev)
          z_dist = dist.Normal(z_mu, z_log_var.exp()).to_event(1)

          z_t = pyro.sample('Z_{}'.format(t + 1), z_dist)

          x_mu, x_log_var = self.emitter(z_t)
          x_dist = dist.StudentT(df, x_mu, x_log_var.exp()).to_event(1)

          pyro.sample('obs_X_{}'.format(t + 1), x_dist, obs=X[t, :, :13])

          missingness_p = self.missing(z_t)
          m_dist = dist.Bernoulli(missingness_p).to_event(1)
          pyro.sample('obs_I_{}'.format(t + 1), m_dist, obs=X[t, :, 13:13 + 12])

        z_pheno[t == L] = z_t[t == L]

        h_prev = h_t
        z_prev = z_t

    with pyro.plate('P_plate', 25, device=pheno.device):
      with pyro.plate('PX_plate', B, device=X.device):
        pheno_p = self.pheno(z_pheno)
        pheno_dist = dist.Bernoulli(pheno_p)

        pyro.sample('obs_P', pheno_dist, obs=pheno)

After turning validation on, I only get warnings about empty guide. Is that a concern here if I just want to see maximum likelihood working?

warnings.warn("Found vars in model but not guide: {}".format(model_vars - guide_vars - enum_vars))

The warning says that your model has latent sites that do not appear in the guide, so you need to add them to the guide. In this case, they are the Z_t sites.

Indeed. I just meant to do MAP estimation before doing variational inference.

Anyway, I’ve added the guide now and validation reports no warnings or errors. I’ll report what the loss trajectory looks like soon.

Thanks!

I updated the guide to look like this. Pyro validation throws no warnings/errors in this case.

  def guide(self, L, X, pheno):
    pyro.module('dmm', self, update_module_params=True)

    q_h_0 = pyro.param('q_h_0', torch.zeros(2, 1, self.latent_size, device=X.device))
    q_mu_0 = pyro.param('q_mu_0', torch.zeros(1, self.latent_size, device=X.device))
    q_log_sigma_0 = pyro.param('q_log_sigma_0', torch.zeros(1, self.latent_size, device=X.device))

    _, sorted_idx = L.sort(descending=True)
    L = L[sorted_idx]
    X = X[:, sorted_idx, :]

    B = pheno.size(0)
    T_max = X.size(0)

    q_h_0 = q_h_0.expand(-1, B, -1).contiguous()
    q_mu_0 = q_mu_0.expand(B, -1)
    q_log_sigma_0 = q_log_sigma_0.expand(B, -1)

    packed_X = torch.nn.utils.rnn.pack_padded_sequence(X[:, :, :13 + 12], lengths=L)
    rnn_out, _ = self.post_rnn(packed_X, q_h_0)
    rnn_out, _ = torch.nn.utils.rnn.pad_packed_sequence(rnn_out)

    with pyro.plate('Z_plate', B, device=X.device):
      z_prev = pyro.sample('Z_0', dist.Normal(q_mu_0, q_log_sigma_0.exp()).to_event(1))
      for t in range(T_max):
        with poutine.mask(mask=(t < L)):
          z_mu, z_log_var = self.post_transition(torch.cat((z_prev, rnn_out[t]), dim=-1))
          z_dist = dist.Normal(z_mu, z_log_var.exp()).to_event(1)

          z_t = pyro.sample('Z_{}'.format(t + 1), z_dist)

          z_prev = z_t

And my graph looks like this so far.

The curve shows effectively the same local pattern repeating (perhaps per minibatch) on a downward trend. I’m using a small learning rate of 4e-5, because otherwise I start getting NaNs after a few gradient updates.

@fehiepsi I’m wondering if you have encountered such a graph trend? I’m just trying to gauge how to interpret (if possible) my ELBO curves. Thanks for the help!

@activatedgeek I see this kind of fluctuation frequently. The reason is that Pyro’s Trace_ELBO is a stochastic estimate, not an analytic one. The loss trends downward while converging to some local minimum. But I have no idea why the same pattern repeats in your case. Is it still ever-decreasing now?
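To make the "stochastic estimate" point concrete, here is a small standard-library sketch of a Monte Carlo ELBO estimate. It assumes a toy model z ~ N(0, 1), x | z ~ N(z, 1) with a deliberately imperfect guide q(z) = N(0.5, 1); the names `log_normal` and `elbo_estimate` are hypothetical and unrelated to Pyro's internals.

```python
import math
import random
import statistics


def log_normal(x, mu, sigma):
    """Log density of Normal(mu, sigma) at x."""
    return (
        -0.5 * math.log(2 * math.pi)
        - math.log(sigma)
        - 0.5 * ((x - mu) / sigma) ** 2
    )


def elbo_estimate(x, q_mu, q_sigma, num_particles, rng):
    """Average of log p(x, z) - log q(z) over samples z ~ q."""
    total = 0.0
    for _ in range(num_particles):
        z = rng.gauss(q_mu, q_sigma)
        log_joint = log_normal(z, 0.0, 1.0) + log_normal(x, z, 1.0)
        total += log_joint - log_normal(z, q_mu, q_sigma)
    return total / num_particles


rng = random.Random(0)
# Repeat the estimate many times with the guide held fixed.
single = [elbo_estimate(1.0, 0.5, 1.0, 1, rng) for _ in range(2000)]
averaged = [elbo_estimate(1.0, 0.5, 1.0, 50, rng) for _ in range(2000)]
print("std with 1 particle:  ", statistics.stdev(single))
print("std with 50 particles:", statistics.stdev(averaged))
```

Even with the guide parameters frozen, the single-particle estimate jumps around from call to call. In Pyro, `Trace_ELBO(num_particles=...)` averages over several samples per step, which may smooth the curve at the cost of more computation.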

Agree with the stochasticity of ELBO.

The dataset I am trying to explain is a collection of time series. I suspect that the variance accumulates over the long-horizon trajectories and the model is taking time to learn those. It has been 24 hours now and it is still the same slow downward trend, with no signs of convergence yet.

Prior to this, I also used a learning rate of 1e-4 which had a much faster downward trend. However, after 8 epochs over the training set, it diverged and started giving undefined model likelihoods. Do you have some tricks to fix divergence here?

I think that specifying different learning rates for different types of parameters can help; see the Pyro tutorial "SVI Part I: An Introduction to Stochastic Variational Inference in Pyro".
Also, your q_log_sigma_0 is initialized to 0. Fixing it to a constant such as -5 instead of making it a parameter can also help (this mimics the behavior of MAP inference).