Hi,
I’m struggling to find a way to use an AutoGuide to train on new data, given another AutoGuide already trained on the main dataset. In this simplified example, assume I have a hierarchy of linear regression slope coefficients:
- beta - coefficients defining the individual linear regression tasks; each task/coefficient has a corresponding type,
- mu_beta_type - coefficients defining the mean of all betas belonging to a given type,
- mu_beta_global - coefficient defining the mean of all mu_beta_type params.
I’m trying to find a solution to the following two problems:
- Assuming I have a test dataset with new tasks (corresponding to new beta coefficients) with known types, how can I construct an AutoGuide that infers the posterior only for the new beta coefficients, while keeping the posterior of all other relevant latent sites (i.e. mu_beta_type) fixed to the posterior learned on the train dataset? (A sketch of what I mean follows this list.)
- How can I extend the above to a test dataset containing a new task with a new type? That is, I would like to infer a new beta for the new task (and possibly a new mu_beta_type for the new type, although that shouldn’t matter at this stage) while keeping the posterior for mu_beta_global fixed.
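To make the intent of problem 1 concrete, here is a rough sketch (an assumption on my part, not code I trust; trained_guide stands for the guide fitted on the train data, called guide in the demo below). It fixes the shared sites to a single posterior draw, which is obviously weaker than keeping the whole posterior fixed, and that gap is exactly what I’m asking about:

# Rough sketch: condition the test-time model on one posterior draw of the
# shared sites, so that a fresh guide only has 'beta_task' left to infer.
posterior_draw = trained_guide(train_data)  # calling a GuideMessenger returns a dict of samples
shared = {name: posterior_draw[name] for name in ('mu_beta_global', 'mu_beta_type')}
test_model = pyro.poutine.condition(model, data=shared)
new_guide = AutoNormalMessenger(test_model)  # 'beta_task' is the only remaining latent site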
To address problem 1, I tried subclassing AutoNormalMessenger and overriding its get_posterior method to return the posterior of the original guide for the latent sites that should be shared (see get_local_guide in the demo code below), but that doesn’t seem like a correct solution…
I would greatly appreciate some guidance on how to implement this correctly using AutoGuides. Below is the demo code.
import numpy as np
import pyro
import pyro.distributions as dist
import torch
from pyro.infer import Predictive
from pyro.infer.autoguide import AutoNormalMessenger
def model(data):
    X = data['x']
    # Hierarchy: global mean -> per-type means -> per-task slopes.
    mu_beta_global = pyro.sample('mu_beta_global', dist.Normal(0, 3))
    with pyro.plate('types', data['n_types']):
        mu_beta_type = pyro.sample('mu_beta_type', dist.Normal(mu_beta_global, 1))
    with pyro.plate('tasks', data['n_tasks']):
        beta = pyro.sample('beta_task', dist.Normal(mu_beta_type[data['task_to_type']], 0.5))
    with pyro.plate('data', len(X)):
        obs_mu = beta[data['obs_tasks']] * X
        return pyro.sample('obs', dist.Normal(obs_mu, 0.3), obs=data.get('obs'))
def get_local_guide(full_guide, infer_sites):
    class LocalGuide(AutoNormalMessenger):
        def get_posterior(self, name, prior):
            # Delegate shared sites to the guide trained on the full dataset;
            # only sites listed in infer_sites get fresh variational parameters.
            if name not in infer_sites:
                return full_guide.get_posterior(name, prior)
            return super().get_posterior(name, prior)
    return LocalGuide
def generate_data(n_types, n_tasks):
    min_data_per_task = 4
    max_data_per_task = 30
    x_dist = dist.Uniform(-3, 3)
    xs = []
    tasks = []
    task_to_type = []
    for i in range(n_tasks):
        task_x = x_dist.sample((np.random.randint(min_data_per_task, max_data_per_task),))
        xs.append(task_x)
        tasks += [i] * len(task_x)
        task_to_type.append(np.random.randint(n_types))
    return {
        'x': torch.concat(tuple(xs)),
        'obs_tasks': torch.tensor(tasks),
        'task_to_type': torch.tensor(task_to_type),
        'n_tasks': n_tasks,
        'n_types': n_types,
    }
train_data = generate_data(n_types=4, n_tasks=50)
test_data = generate_data(n_types=4, n_tasks=1)
betas = pyro.poutine.trace(model).get_trace(train_data).nodes['beta_task']['value'] # sample betas from prior
obs_model = pyro.poutine.do(model, data={'beta_task': betas})
train_data['obs'] = obs_model(train_data)
test_data['obs'] = obs_model(test_data)
def train(model, guide, data):
    pyro.clear_param_store()
    optimizer = torch.optim.Adam
    scheduler = pyro.optim.MultiStepLR({'optimizer': optimizer, 'optim_args': {'lr': 0.1}, 'milestones': [1000, 2000, 3000]})
    elbo = pyro.infer.Trace_ELBO()
    svi = pyro.infer.SVI(model, guide, scheduler, elbo)
    losses = []
    for step in range(4000):
        loss = svi.step(data) / data['x'].shape[0]
        losses.append(loss)
        if step % 500 == 0:
            print("{}: Elbo loss: {}".format(step, loss))
        scheduler.step()  # step the LR scheduler once per SVI step (milestones are in steps)
guide = AutoNormalMessenger(model)
train(model, guide, train_data)  # fit the full guide on the train dataset

local_guide = get_local_guide(guide, infer_sites=['beta_task'])(model)
train(model, local_guide, test_data)  # fit only 'beta_task' on the test dataset
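For completeness, this is roughly how I then look at the posterior predictive on the test tasks. The exact recipe is an assumption on my part (num_samples is arbitrary): I draw beta_task from the locally fitted guide and push the draws through the model with 'obs' dropped so that it gets resampled:

# Posterior predictive check on the test tasks using the locally fitted guide.
num_samples = 500
beta_draws = torch.stack([local_guide(test_data)['beta_task'] for _ in range(num_samples)])
unobserved_test = {k: v for k, v in test_data.items() if k != 'obs'}
predictive = Predictive(model, posterior_samples={'beta_task': beta_draws})
samples = predictive(unobserved_test)
# The other latents are drawn from the prior here, but 'obs' depends only on beta_task.
print(samples['obs'].mean(0))  # posterior predictive mean per data point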