AutoGuide for new data in a hierarchical model


I’m struggling to find a way to use an AutoGuide to train on new data, given another AutoGuide already trained on the main dataset. In this simplified example, I assume a hierarchy of linear regression slope coefficients:

  • beta - the coefficient defining each individual linear regression task; each task/coefficient has a corresponding type,
  • mu_beta_type - coefficients defining the mean of all betas belonging to that type
  • mu_beta_global - the coefficient defining the mean of all mu_beta_type params

I’m trying to find a solution for the following problems:

  1. Assuming I have a test dataset with new tasks (corresponding to new beta coefficients) of known types, how can I construct an AutoGuide that only infers the posterior for the new beta coefficients while keeping the posterior of all other relevant latent sites (i.e. mu_beta_type) fixed to the posterior trained on the train dataset?
  2. How can I extend the above to a test dataset with a new task of a new type? I.e. I would like to infer the new beta for the new task (and possibly mu_beta_type for the new type, although that shouldn’t matter at this stage) while keeping the posterior for mu_beta_global fixed.

To address problem 1, I tried subclassing AutoNormalMessenger and modifying the get_posterior method to return the posterior of the original guide for latent sites that should be shared, but that doesn’t seem like a correct solution…

I would greatly appreciate getting some guidance on how to implement this correctly using AutoGuides. Below is the demo code.

import numpy as np
import pyro
import pyro.distributions as dist
import torch
from pyro.infer import Predictive
from pyro.infer.autoguide import AutoNormalMessenger

def model(data):
    X = data['x']
    mu_beta_global = pyro.sample('mu_beta_global', dist.Normal(0, 3))

    with pyro.plate('types', data['n_types']):
        mu_beta_type = pyro.sample('mu_beta_type', dist.Normal(mu_beta_global, 1))

    with pyro.plate('tasks', data['n_tasks']):
        beta = pyro.sample('beta_task', dist.Normal(mu_beta_type[data['task_to_type']], 0.5))

    with pyro.plate('data', len(X)):
        obs_mu = beta[data['obs_tasks']] * X
        return pyro.sample('obs', dist.Normal(obs_mu, 0.3), obs=data.get('obs'))

def get_local_guide(full_guide, infer_sites):
    class LocalGuide(AutoNormalMessenger):
        def get_posterior(self, name, prior):
            if name not in infer_sites:
                return full_guide.get_posterior(name, prior)
            return super().get_posterior(name, prior)
    return LocalGuide

def generate_data(n_types, n_tasks):
    min_data_per_task = 4
    max_data_per_task = 30
    x_dist = dist.Uniform(-3, 3)
    xs = []
    tasks = []
    task_to_type = []
    for i in range(n_tasks):
        task_x = x_dist.sample((np.random.randint(min_data_per_task, max_data_per_task),))
        xs.append(task_x)
        tasks += [i] * len(task_x)
        task_to_type.append(np.random.randint(n_types))  # assumption: each task gets a random type

    return {
        'x': torch.concat(tuple(xs)),
        'obs_tasks': torch.tensor(tasks),
        'task_to_type': torch.tensor(task_to_type),
        'n_tasks': n_tasks,
        'n_types': n_types,
    }

train_data = generate_data(n_types=4, n_tasks=50)
test_data = generate_data(n_types=4, n_tasks=1)
betas = pyro.poutine.trace(model).get_trace(train_data).nodes['beta_task']['value']  # sample betas from prior
obs_model = pyro.poutine.condition(model, data={'beta_task': betas})
train_data['obs'] = obs_model(train_data)
test_data['obs'] = obs_model(test_data)

def train(model, guide, data):
    optimizer = torch.optim.Adam
    scheduler = pyro.optim.MultiStepLR({'optimizer': optimizer, 'optim_args': {'lr': 0.1}, 'milestones': [1000, 2000, 3000]})
    elbo = pyro.infer.Trace_ELBO()
    svi = pyro.infer.SVI(model, guide, scheduler, elbo)
    losses = []
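    for step in range(4000):
        loss = svi.step(data) / data['x'].shape[0]
        scheduler.step()  # advance the MultiStepLR schedule once per SVI step
        losses.append(loss)
        if step % 500 == 0:
            print("{}: Elbo loss: {}".format(step, loss))
    return losses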

guide = AutoNormalMessenger(model)
train(model, guide, train_data)

local_guide = get_local_guide(guide, infer_sites=['beta_task'])(model)
train(model, local_guide, test_data)

hello @maciejk

assuming your trainable parameters are all in the guide you can probably do something like this:

  • use the “lower level” training pattern in this tutorial. in particular you can create optimizers that only target the parameters of interest in the guide you want to update:

    optimizer1 = torch.optim.Adam(guide1.parameters(), lr=0.001)

  • use AutoGuideList to combine different guides (using poutine.block as needed to define local/sub-guides on the right subsets of latent variables):

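from pyro.infer.autoguide import AutoGuideList, AutoNormal
guide = AutoGuideList(model)
guide1 = AutoNormal(pyro.poutine.block(model, hide=["local"]))   # every site except "local"
guide2 = AutoNormal(pyro.poutine.block(model, expose=["local"]))  # only the "local" site
guide.append(guide1)
guide.append(guide2)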

Hi @martinjankowiak

Thanks a lot for the reply! It sounds like a very good solution to my problem. I tested it (with a couple of modifications to fix some errors) and it appears to work well. One thing I noticed is that the loss often fluctuates more on the test (new) data than on the training data, even though the inferred new local parameters look very stable. It may be that some of the shared parameters are not supported by the test data (e.g. there are no tasks of certain types) while they are still included in the ELBO calculation. Is there a way to exclude the shared sites from the ELBO calculation (this should not affect the optimization objective, assuming the shared and local posteriors are independent)?

For reference, the modified demo code is as follows:

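from pyro import poutine
from pyro.infer.autoguide import AutoGuideList, AutoNormal

train_data = generate_data(n_types=10, n_tasks=50)
test_data = generate_data(n_types=1, n_tasks=5)
test_data['n_types'] = train_data['n_types']  # keep the 'types' plate the same size as in training
mu_beta_type = pyro.poutine.trace(model).get_trace(train_data).nodes['mu_beta_type']['value']  # sample type means from the prior
obs_model = pyro.poutine.condition(model, data={'mu_beta_type': mu_beta_type})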
train_data['obs'] = obs_model(train_data)
test_data['obs'] = obs_model(test_data)

def train(model, guide, data, optimize_params=None):
    loss_fn = lambda model, guide: pyro.infer.Trace_ELBO().differentiable_loss(model, guide, data)
    guide(data)  # run once so the autoguide creates its parameters
    if optimize_params is None:
        optimize_params = guide.parameters()
    optimizer = torch.optim.Adam(optimize_params, lr=0.01)
    for i in range(6000):
        optimizer.zero_grad()
        loss = loss_fn(model, guide) / data['x'].shape[0]
        loss.backward()
        optimizer.step()
        if i % 100 == 0:
            print(f"{i}: Elbo loss: {loss:.3f}")

guide = AutoGuideList(model)
guide_shared = AutoNormal(poutine.block(model, hide=["beta_task"]))
guide_local = AutoNormal(poutine.block(model, expose=["beta_task"]))
guide.append(guide_shared)
guide.append(guide_local)

train(model, guide, train_data)

guide_shared_test = AutoNormal(poutine.block(model, hide=["beta_task"]))
guide_local_test = AutoNormal(poutine.block(model, expose=["beta_task"]))
guide_test = AutoGuideList(model)
guide_test.append(guide_shared_test)
guide_test.append(guide_local_test)
train(model, guide_test, test_data, optimize_params=guide_local_test.parameters())

i’m not sure what kind of fluctuations you mean, but i’ll note that lr=0.01 is a moderately high learning rate.

i’m not sure if there is any easy way to exclude terms from the ELBO if you’re using autoguides, since that choice limits your control over the guide to some extent. with a custom model/guide pair you can mask distributions so that their log probabilities do not enter into ELBO calculations (of course in general this will affect learning):

pyro.sample("latent", my_dist.mask(False))

I don’t think it’s a learning rate problem: I used the same value both for training the shared (and local) parameters on the train data and for training just the local parameters on the new data, and in the former case I don’t see these fluctuations. To make it more concrete, here are some of the last epochs (using a lower LR of 0.001) for:

  1. training all parameters on train data (10 task types, 50 tasks)
15000: Elbo loss: 0.391
15100: Elbo loss: 0.389
15200: Elbo loss: 0.389
15300: Elbo loss: 0.388
15400: Elbo loss: 0.389
15500: Elbo loss: 0.390
15600: Elbo loss: 0.390
15700: Elbo loss: 0.389
15800: Elbo loss: 0.391
15900: Elbo loss: 0.389
  2. training only the local parameters on new data (50 new tasks, all of the same type, meaning that 9 of the 10 types have no data)
15000: Elbo loss: 0.395
15100: Elbo loss: 0.385
15200: Elbo loss: 0.394
15300: Elbo loss: 0.394
15400: Elbo loss: 0.388
15500: Elbo loss: 0.411
15600: Elbo loss: 0.385
15700: Elbo loss: 0.385
15800: Elbo loss: 0.384
15900: Elbo loss: 0.422

So it doesn’t seem like the discrepancy is caused by random chance or a high learning rate.

Anyway, fair enough regarding autoguides and the limited control they give over the guide and the ELBO calculation.