Fast posterior point estimates without Predictive

Once again, I’m working on the model described in the thread "AutoDiagonalNormal found no latent variables; Use an empty guide instead" on the Pyro Discussion Forum.

I’m looking for the fastest way to compute point-estimate metrics like MAE and MSE on the validation set during training. My current evaluation function is shown below. I use Predictive to draw samples from the posterior and take their mean to compute MAE and MSE. I know I could parallelize Predictive, but I’m looking for a faster alternative (I still need to learn how to update my code to properly run Predictive in parallel; a rough sketch of the parallel version appears after the code below). My thinking is that I could take the MAP estimate of all my parameters to get a point prediction for each observation. But how exactly can I do that? Do I need to declare a separate AutoDelta guide and use that?

def evaluate(guide, model, criterion, val_dataloader_iter, validation_steps, device, metric_agg_fn=None):
  model.eval()  # Set model to evaluate mode

  predictive_obs = Predictive(model, guide=guide, num_samples=200, return_sites = ['obs'])
  # statistics
  running_mae_loss = 0.0
  running_mse_loss = 0.0
  running_elbo = 0.0

  # Iterate over all the validation data.
  for step in range(validation_steps):
      pd_batch = next(val_dataloader_iter)
      pd_batch['features'] = torch.transpose(torch.stack([pd_batch[x] for x in x_feat]), 0, 1)#.double()    
      inputs = pd_batch['features'].to(device)
      labels = pd_batch[y_name].to(device)
      samples_obs = predictive_obs(inputs)
      mae_loss = torch.absolute(torch.mean(samples_obs['obs'], dim = 0) - labels).mean()
      mse_loss = torch.pow(torch.mean(samples_obs['obs'], dim = 0) - labels, 2).mean()
      running_mae_loss += mae_loss
      running_mse_loss += mse_loss
      running_elbo += svi.evaluate_loss(inputs, labels)#elbo_loss.loss(model, guide, inputs, labels)

  # The losses are averaged across observations for each minibatch.
  epoch_mae_loss = running_mae_loss / validation_steps
  epoch_mse_loss = running_mse_loss / validation_steps
  elbo_per_life = running_elbo / df_val.count()

  return epoch_mae_loss, epoch_mse_loss, elbo_per_life
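For reference, Predictive's parallel mode would look roughly like the sketch below. It is only a sketch: parallel=True requires the model and guide to be fully vectorized (every batch dimension declared with pyro.plate), which this model may or may not satisfy.

# Hypothetical sketch: vectorized posterior-predictive sampling.
# parallel=True draws all num_samples in one batched forward pass instead of
# looping, but only works if model/guide are vectorizable over the sample dim.
from pyro.infer import Predictive

predictive_obs = Predictive(
    model,
    guide=guide,
    num_samples=200,
    return_sites=['obs'],
    parallel=True,
)
samples_obs = predictive_obs(inputs)         # samples_obs['obs'] has shape (200, batch, ...)
point_pred = samples_obs['obs'].mean(dim=0)  # posterior-mean point prediction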

If you are using one of Pyro’s autoguides, you can get cheap point estimates of latent variables via guide.median(). You can turn those latent variables into posterior predictive values via poutine.condition():

latents = guide.median()
posterior_predictive = poutine.trace(
    poutine.condition(model, latents)
).get_trace()
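You can then read any site’s value back out of that trace by name; for example, assuming the observation site is named 'obs' as in the evaluate code above, this pulls out a single draw of obs given the median latents:

# Read the conditioned trace's observation site (a single draw given the median latents).
obs_draw = posterior_predictive.nodes['obs']['value']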

Or, if your model has a good return statement, you can simply call the conditioned model:

latents = guide.median()
value = poutine.condition(model, latents)()
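To make the "good return statement" concrete, here is a minimal hedged sketch of a toy model (not the model from this thread; all site names are illustrative) whose return value is the deterministic mean of the likelihood, so the conditioned call above yields a point prediction directly:

import torch
import pyro
import pyro.distributions as dist
from pyro import poutine

def toy_model(x, y=None):
    # Hypothetical regression model; "weight", "bias", "sigma", "obs" are illustrative names.
    weight = pyro.sample("weight", dist.Normal(0., 1.))
    bias = pyro.sample("bias", dist.Normal(0., 1.))
    sigma = pyro.sample("sigma", dist.LogNormal(0., 1.))
    mean = weight * x + bias
    with pyro.plate("data", x.shape[0]):
        pyro.sample("obs", dist.Normal(mean, sigma), obs=y)
    return mean  # the "good return statement": a deterministic point prediction

# Usage (assuming guide is an autoguide trained against toy_model):
# latents = guide.median()
# point_prediction = poutine.condition(toy_model, latents)(x_val)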

I’m running into a weird bug: if I add this to my model evaluation step, model training fails. This is a continuation of the model that I’ve posted about several times now :slight_smile:

def train_and_evaluate_SVI(svi, criterion, model, guide, bs, ne, lr=0.001):
 
  model = model.to(device)
  with converter_train.make_torch_dataloader(batch_size=bs, shuffling_queue_capacity = 0) as train_dataloader, \
  converter_val.make_torch_dataloader(batch_size=bs, shuffling_queue_capacity = 0) as val_dataloader:

      train_dataloader_iter = iter(train_dataloader)
      steps_per_epoch = len(converter_train) // bs

      val_dataloader_iter = iter(val_dataloader)
      validation_steps = np.max((1, len(converter_val) // bs_val))

      for epoch in range(ne):
          if (epoch + 1) % 10 == 0: print('-' * 10)
          if (epoch + 1) % 10 == 0: print('Epoch {}/{}'.format(epoch + 1, ne))
          
          train_loss = train_one_epoch_SVI(svi, train_dataloader_iter, steps_per_epoch, epoch, device)
          
          val_loss = evaluate(guide, model, criterion, val_dataloader_iter, validation_steps, epoch, device) 

Here is my evaluate function. My goal is to use the fast approximation from the latents for most epochs and to calculate metrics from the full posterior predictive every 10 epochs.

def evaluate(guide, model, criterion, val_dataloader_iter, validation_steps, epoch, device, metric_agg_fn=None):
  model.eval()  # Set model to evaluate mode

  predictive_obs = Predictive(model, guide=guide, num_samples=int(10e2), return_sites = ['obs'])
  # statistics
  running_mae_loss = 0.0
  running_mse_loss = 0.0
  running_elbo = 0.0
  total_lives = 0
  latents = guide.median()
  samples_obs = None
    
  # Iterate over all the validation data.
  for step in range(validation_steps):
      pd_batch = next(val_dataloader_iter)
      pd_batch['features'] = torch.transpose(torch.stack([pd_batch[x] for x in x_feat]), 0, 1)
      inputs = pd_batch['features'].to(device)
      labels = pd_batch[y_name].to(device)
      if (epoch + 1) % EVAL_EVERY_N_EPOCH == 0:
        samples_obs = predictive_obs(inputs, subsample=False)
        sampled_mean = torch.mean(samples_obs['obs'], dim = 0)
        mae_loss = torch.absolute(sampled_mean - labels).mean()
        mse_loss = torch.pow(sampled_mean - labels, 2).mean()
      else:
        # to get a fast estimate of the posterior, we use the median of the variational distribution
        values = poutine.condition(model, latents)(inputs, y=None, subsample=False)        
        # mean = mu * (1 - theta)
        est_mean = 0#values.reshape((NUM_PARAMS, -1))[0] * (1 - values.reshape((NUM_PARAMS, -1))[2]) # model.class_mean(values)#
        mae_loss = 0#torch.absolute(est_mean - labels).mean()
        mse_loss = 0#torch.mean(torch.pow(est_mean - labels, 2))
      #print(samples_obs['obs'].size())
      
      running_mae_loss += mae_loss
      running_mse_loss += mse_loss
      running_elbo += svi.evaluate_loss(inputs, labels)
      total_lives += inputs.shape[0]

  # The losses are averaged across observations for each minibatch.
  epoch_mae_loss = running_mae_loss / validation_steps
  epoch_mse_loss = running_mse_loss / validation_steps
  elbo_per_life = running_elbo / total_lives

  print(f'Validation Loss: MAE = {epoch_mae_loss:.2f}, MSE = {epoch_mse_loss:.2f}, ELBO = {elbo_per_life:.3f}')
  return epoch_mae_loss, epoch_mse_loss, elbo_per_life, samples_obs

To get my model to run, I have to set est_mean = 0. Using the actual value, values.reshape((NUM_PARAMS, -1))[0] * (1 - values.reshape((NUM_PARAMS, -1))[2]), causes training to fail with invalid estimates:

Expected parameter loc (Parameter of shape (99,)) of distribution Normal(loc: torch.Size([99]), scale: torch.Size([99])) to satisfy the constraint Real(), but found invalid values:

Literally just switching between 0 and the actual calculation causes it to fail. What exactly could be going on here?

Actually, the problem isn’t limited to just that calculation. Adding just a print(model.training) to train_one_epoch_SVI causes the same error. This problem just started today. I’m running in Databricks.

Tried downgrading to version 1.8.0. Same issue. The only way I can get it to train when I add new lines is to restart the Databricks cluster completely. Clearing the session without restarting the server does not work.

Hmm, this sounds like it might simply be a spurious random error. Are you fixing your RNG seed before starting training? The solution is probably to .clamp(min=...) in various places in your model. Note that evaluation code will consume randomness, changing the random number stream that training sees. To avoid differences between runs with and without evaluation, you should fork the RNG using with torch.random.fork_rng(). Additionally, I’d recommend torch.no_grad() and poutine.block() to minimize side effects of evaluation.

with torch.random.fork_rng(), torch.no_grad(), poutine.block():
    val_loss = evaluate(guide, model, criterion, val_dataloader_iter, validation_steps, epoch, device)
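For what it’s worth, a generic illustration of the .clamp(min=...) idea (not code from the actual model in this thread; raw_mu and raw_sigma are hypothetical inputs) is to bound any raw quantity before it reaches a distribution’s constrained parameters:

import torch
import pyro
import pyro.distributions as dist

def safe_obs(raw_mu, raw_sigma, y=None):
    # Hypothetical helper: keep the scale strictly positive and the loc within
    # a finite range before constructing the likelihood.
    sigma = torch.nn.functional.softplus(raw_sigma).clamp(min=1e-5)
    mu = raw_mu.clamp(min=-1e6, max=1e6)
    return pyro.sample("obs", dist.Normal(mu, sigma), obs=y)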

Thank you for the suggestions! I was not familiar with fork_rng or poutine.block, so I’ve added your line to my code. I already use .clamp(min=...) to keep my parameters under control and ClippedAdam to keep my gradients under control. I’m also setting my seed with pyro.set_rng_seed(123456789).

However, I still get the same error unless I restart the cluster completely after making any changes. The weirdest thing is that this just started happening today. Yesterday, I ran a few dozen different model configurations with various code, distribution, etc. changes and didn’t get a single error.