Organizing model components

Hi all,

I’m working on a Pyro model that has a VAE component and a regression component. To test different regression models, I would like to create classes representing these models, and a general VAE model that takes one of these regression models as an argument. So I would like to write:

R1 = RegressionModel1()
M1 = VAE(R1)

R2 = RegressionModel2()
M2 = VAE(R2)

What I can’t figure out is how to sample parameters in these regression models, how to define guides for them (manual, or possibly AutoGuides), and how to combine these guides with the guide for the VAE. As an example, let’s set the VAE aside and look at a simpler model:

import pyro
import pyro.distributions as dist
import pyro.distributions.constraints as constraints
from pyro.nn import PyroModule, PyroParam, PyroSample
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam
from tqdm import trange
import torch
import matplotlib.pyplot as plt

class RegressionModel(PyroModule):
    """Linear regression ``yhat = a * x + b``.

    ``a`` and ``b`` are declared as ``PyroParam``s, so they are learned as
    point estimates rather than sampled as latent variables.

    Args:
        a_init: initial value of the slope ``a`` (default ``1.0``).
        b_init: initial value of the intercept ``b`` (default ``0.0``).
    """

    def __init__(self, a_init=1.0, b_init=0.0):
        super().__init__()

        # NOTE(review): to infer a distribution over a/b instead of point
        # estimates, these could become PyroSample(prior) sites, with a guide
        # (manual or an AutoGuide) covering them -- not done here.
        self.a = PyroParam(torch.tensor(float(a_init)))
        self.b = PyroParam(torch.tensor(float(b_init)))

    def forward(self, x):
        """Return the predicted mean ``a * x + b`` (broadcast against ``x``)."""
        return self.a * x + self.b


class Model(PyroModule):
    """Regression likelihood with a learned observation-noise scale.

    The regression component is injected at construction time. This class
    only calls it on ``x`` to get the predicted mean, so it does not need to
    know which parameters the regression model owns.

    The only latent variable declared here is ``log_sigma``. It has a
    ``Normal(0, 1)`` prior in :meth:`model` and a
    ``Normal(sigma_vari_loc, sigma_vari_scale)`` variational distribution
    in :meth:`guide`.
    """

    def __init__(self, regression_model):
        super().__init__()

        # Submodule mapping x to the predicted mean yhat.
        self.regression_model = regression_model
        # Variational parameters of the guide's Normal over log_sigma.
        self.sigma_vari_loc = PyroParam(torch.tensor(-1.0))
        self.sigma_vari_scale = PyroParam(
            torch.tensor(0.5), constraint=constraints.positive
        )

    def model(self, x, y):
        """Generative model: ``y ~ Normal(regression_model(x), exp(log_sigma))``.

        Args:
            x: regression inputs; ``x.shape[0]`` sets the size of the data plate.
            y: observed targets (conditioned on via ``obs=``).
        """
        # Register this module's parameters (including the regression
        # submodule's) with Pyro under the "model" prefix; this method is
        # handed to SVI directly rather than invoked via the module's __call__.
        pyro.module("model", self)

        prior = dist.Normal(torch.tensor(0.0), torch.tensor(1.0))
        log_sigma = pyro.sample("log_sigma", prior)

        mean = self.regression_model(x)
        noise_scale = log_sigma.exp()

        with pyro.plate("data", x.shape[0]):
            likelihood = dist.Normal(mean, noise_scale)
            pyro.sample("y", likelihood, obs=y)

    def guide(self, x, y):
        """Variational guide: ``log_sigma ~ Normal(sigma_vari_loc, sigma_vari_scale)``.

        ``x`` and ``y`` are unused; they are accepted so the signature matches
        :meth:`model`.
        """
        pyro.module("model", self)

        posterior = dist.Normal(self.sigma_vari_loc, self.sigma_vari_scale)
        pyro.sample("log_sigma", posterior)

    def sample_sigma(self, n):
        """Draw ``n`` samples of ``sigma = exp(log_sigma)`` from the guide's Normal.

        Returns a tensor with leading dimension ``n``.
        """
        posterior = dist.Normal(self.sigma_vari_loc, self.sigma_vari_scale)
        return posterior.sample((n,)).exp()

I can now construct an object of class Model and estimate the parameters a and b, as well as the variational parameters sigma_vari_loc and sigma_vari_scale. The Model class does not need to know about a and b in the RegressionModel class, so I could easily swap RegressionModel for something else. However, suppose that I want to estimate a (posterior) distribution over a and b instead of point estimates. How do I define a guide for RegressionModel, and how do I call that guide from the guide method of Model? Or should I approach this completely differently?

I used this code to estimate the parameters:

pyro.clear_param_store()

## sample some fake data
a_gt = 0.5
b_gt = 1.5
sigma_gt = 0.1
n = 1000

x = torch.randn((n,))
y = a_gt * x + b_gt + torch.randn((n,)) * sigma_gt

## create the regression model and model
R = RegressionModel()
M = Model(R)

optim = Adam({"lr": 1e-2})
# Keep the ELBO object under its own name; the per-step loss below is a float.
elbo_loss = Trace_ELBO()

## construct an SVI object and train the model
svi = SVI(M.model, M.guide, optim, elbo_loss)
elbo = []
# training loop: svi.step returns the loss estimate for this single step
for epoch in (pbar := trange(3000)):
    step_loss = svi.step(x, y)
    elbo.append((epoch, step_loss))
    pbar.set_description(f"train loss: {step_loss:0.2f}")

## look at results
print("sigma_vari_loc:", M.sigma_vari_loc)
print("sigma_vari_scale:", M.sigma_vari_scale)
# Model stores the submodule as `regression_model`; it has no `sub_model`.
print("a:", M.regression_model.a)
print("b:", M.regression_model.b)

sigma = M.sample_sigma(1000).detach().numpy()
fig, ax = plt.subplots(1, 1)
ax.hist(sigma, 50, density=True)

Any suggestions are very welcome!

thanks,

Chris