Converting a neural network to Pyro and making predictions

My question is mostly about prediction, but feel free to criticise how I’ve written my Pyro code. I have the following linear model, and I want to put priors on its weights and biases:

import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self, dims):
        super().__init__()  # required before assigning submodules
        self.linear = nn.Linear(dims, 1)

    def forward(self, x):
        return self.linear(x)

model = Model(1)

I followed this blog post on converting neural nets to Pyro and ended up with the following code:

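import pyro
from pyro.distributions import Normal, Uniform, Laplace
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoDiagonalNormal
from pyro.optim import Adam

def bayes_model(x_data, y_data):
    # define priors: a standard Normal over every weight and bias of `model`
    get_params = lambda w: (torch.zeros_like(w), torch.ones_like(w))
    priors = {name: Normal(*get_params(w)) for name, w in model.named_parameters()}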
    # lift onto a random pyro module
    lifted_module = pyro.random_module("module", model, priors)
    lifted_reg_model = lifted_module()
    # define rest of model with likelihood
    yhat = lifted_reg_model(x_data)
    scale = pyro.sample("sigma", Uniform(0, 5))
    pyro.sample("obs", Laplace(yhat, scale), obs=y_data)
    return yhat

guide = AutoDiagonalNormal(bayes_model)
optim = Adam({"lr": 0.03})
svi = SVI(bayes_model, guide, optim, loss=Trace_ELBO(), num_samples=1000)

And I train the pyro model with:

epochs = 15
for _ in range(epochs):
    loss = 0
    for x_data, y_data in train_dl:
        loss += svi.step(x_data, y_data)
    print(f"loss: {loss / len(train_data):.4f}")

My question is: how do I do predictive inference now that I have a trained Pyro model?

A minimal working example is included in this colab notebook.

Here is my feeble attempt at predictive inference, but I’m clearly doing it wrong:

num_samples = 10
def predict(x):
    sampled_models = [guide(None, None) for _ in range(num_samples)]
    yhats = [model(x).data for model in sampled_models]
    return yhats

pred_y_train = predict(torch.Tensor(x))

@sachinruk I think you can find the answer in . :slight_smile:

Hi @sachinruk, @fehiepsi’s pointer is great. Note that the blog post is outdated. Pyro 1.0 now includes a PyroModule class to help make your nn.Modules Bayesian. The new syntax is:

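+ import pyro.distributions as dist
+ from pyro.nn import PyroModule, PyroSample
+
+ def normal_like(x):
+     return dist.Normal(0, 1).expand(x.shape).to_event(x.dim())

- class Model(nn.Module):
+ class Model(PyroModule):
      def __init__(self, dims):
+         super().__init__()
-         self.linear = nn.Linear(dims, 1)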
+         self.linear = PyroModule[nn.Linear](dims, 1)
+         self.linear.weight = PyroSample(normal_like(self.linear.weight))
+         self.linear.bias = PyroSample(normal_like(self.linear.bias))
      def forward(self, x):
          return self.linear(x)