pyro.infer.importance.psis_diagnostic gives errors

I am not able to get PSIS to work.

Reproducible example from the regression tutorial:

```python
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroSample
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoDiagonalNormal

DATA_URL = "https://d2hg8soec8ck9v.cloudfront.net/datasets/rugged_data.csv"
data = pd.read_csv(DATA_URL, encoding="ISO-8859-1")
df = data[["cont_africa", "rugged", "rgdppc_2000"]]
df = df[np.isfinite(df.rgdppc_2000)].copy()
df["rgdppc_2000"] = np.log(df["rgdppc_2000"])

# Interaction feature from the tutorial: gives the 3 inputs that
# BayesianRegression(3, 1) expects.
df["cont_africa_x_rugged"] = df["cont_africa"] * df["rugged"]
train = torch.tensor(
    df[["cont_africa", "rugged", "cont_africa_x_rugged", "rgdppc_2000"]].values,
    dtype=torch.float,
)
x_data, y_data = train[:, :-1], train[:, -1]


class BayesianRegression(PyroModule):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = PyroModule[nn.Linear](in_features, out_features)
        self.linear.weight = PyroSample(dist.Normal(0., 1.).expand([out_features, in_features]).to_event(2))
        self.linear.bias = PyroSample(dist.Normal(0., 10.).expand([out_features]).to_event(1))

    def forward(self, x, y=None):
        sigma = pyro.sample("sigma", dist.Uniform(0., 10.))
        mean = self.linear(x).squeeze(-1)
        with pyro.plate("data", x.shape[0]):
            obs = pyro.sample("obs", dist.Normal(mean, sigma), obs=y)
        return mean


model = BayesianRegression(3, 1)
guide = AutoDiagonalNormal(model)
adam = pyro.optim.Adam({"lr": 0.03})
svi = SVI(model, guide, adam, loss=Trace_ELBO())

pyro.clear_param_store()
num_iterations = 1500
for j in range(num_iterations):
    # calculate the loss and take a gradient step
    loss = svi.step(x_data, y_data)
    if j % 100 == 0:
        print("[iteration %04d] loss: %.4f" % (j + 1, loss / len(x_data)))
```

Now apply PSIS:

```python
from pyro.infer import importance

k_hat = importance.psis_diagnostic(model, guide, x_data, y_data)
```

Error Summary below:

```text
RuntimeError: t() expects a tensor with <= 2 dimensions, but self is 10D
                Trace Shapes:                            
                 Param Sites:                            
                Sample Sites:                            
num_particles_vectorized dist                       |    
                        value                  1000 |    
                   sigma dist 1000 1 1 1 1 1 1    1 |    
                        value 1000 1 1 1 1 1 1    1 |    
           linear.weight dist 1000 1 1 1 1 1 1    1 | 1 3
                        value 1000 1 1 1 1 1 1    1 | 1 3
             linear.bias dist 1000 1 1 1 1 1 1    1 | 1  
                        value 1000 1 1 1 1 1 1    1 | 1  
```

Can you please help me with this error?
Thanks!

I found this issue which is exactly the same as mine:

Apologies, should have searched harder!

As a follow-up to that discussion:

  • Since importance.psis_diagnostic cannot be used with PyTorch modules, how else can we check whether SVI has converged?

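Well, it’s not that it can’t be used with modules; it’s that nn.Linear isn’t fully broadcastable (it assumes a 2-D weight, while psis_diagnostic adds batch dimensions for its vectorized particles). Get rid of nn.Linear and there’s no issue.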
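The trace in the question accounts for the "10D": psis_diagnostic evaluates its particles in one vectorized batch, so each latent site gets 8 batch dims (1000 particles plus seven singleton dims, consistent with a max_plate_nesting of 7), and linear.weight arrives with shape (1000, 1, 1, 1, 1, 1, 1, 1, 1, 3). nn.Linear's forward calls torch.nn.functional.linear, which is written for a 2-D weight. The plain-PyTorch sketch below reproduces this outside Pyro; the data size N and the random values are arbitrary, and only the shapes come from the trace.

```python
import torch
import torch.nn.functional as F

N, in_features, out_features = 5, 3, 1   # N is arbitrary; 3 -> 1 as in BayesianRegression(3, 1)
x = torch.randn(N, in_features)

# Shapes reported in the trace: 8 batch dims (1000 vectorized particles plus
# 7 singleton plate dims) followed by each site's event shape.
batch = (1000,) + (1,) * 7
weight = torch.randn(batch + (out_features, in_features))  # 10-D
bias = torch.randn(batch + (out_features,))                # 9-D

# A single draw (what ordinary SVI sees) has a 2-D weight and works.
single = F.linear(x, weight[(0,) * 8], bias[(0,) * 8])

# The vectorized draw does not: F.linear expects a 2-D weight, so this is
# expected to raise a RuntimeError like the one in the question.
try:
    F.linear(x, weight, bias)
    error = None
except RuntimeError as e:
    error = str(e)
print(single.shape, error)
```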

Thanks @martinjankowiak.
nn.Linear is a very common module in most DNNs, so this unfortunately rules out PSIS for most DNN-based use cases.

Is there something else in the literature that can be used to check the ‘goodness of fit’ of the guide?
I can try my hand at implementing it (in case Pyro hasn’t already).

Maybe it’s easiest to add broadcasting support around the self.linear(-) call. Would the following work around nn.Linear's limitation?

```diff
- mean = self.linear(x).squeeze(-1)
+ batch_shape = x.shape[:-1]
+ x = x.reshape(-1, x.size(-1))
+ mean = self.linear(x).reshape(batch_shape)
```
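A possible reason the same error comes back: nn.Linear (via F.linear) already accepts inputs of shape (*, in_features), so flattening and restoring the batch dims of x doesn't change what the layer computes. A small plain-PyTorch check with an ordinary 2-D weight (the shapes here are arbitrary, chosen for illustration):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
linear = nn.Linear(3, 1)   # ordinary 2-D weight, as in the non-vectorized case
x = torch.randn(4, 5, 3)   # arbitrary batch shape (4, 5) on the input

# Original line: nn.Linear broadcasts over all leading dims of x.
mean_orig = linear(x).squeeze(-1)

# Suggested patch: flatten the batch dims, apply the layer, restore them.
batch_shape = x.shape[:-1]
mean_patched = linear(x.reshape(-1, x.size(-1))).reshape(batch_shape)

print(mean_orig.shape, torch.allclose(mean_orig, mean_patched, atol=1e-6))
```

Since both lines agree when only x is batched, the dimensions that break under psis_diagnostic have to be the ones on the weight and bias.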

Hi @fritzo, I did a quick substitution and… got the same error. :frowning:
Colab

The problem is in broadcasting the weights/bias. What’s needed is 5-10 lines of code that define a custom PyTorch module that does a linear op without assuming, e.g., that the weights are 2-dimensional tensors.
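A minimal sketch of such a linear op, written as a plain function rather than a module. `broadcast_linear` is a hypothetical helper, not part of PyTorch or Pyro. It assumes x is 2-D (N, in_features) and that the weight and bias sites are sampled outside the data plate (dim=-1), so their rightmost batch dim is 1 and can broadcast against N.

```python
import torch
import torch.nn.functional as F


def broadcast_linear(x, weight, bias):
    """Hypothetical helper: x @ weight^T + bias, allowing extra leading batch
    dims on weight and bias (e.g. vectorized particles).

    x:      (N, in_features)
    weight: (*batch, out_features, in_features)
    bias:   (*batch, out_features)
    Assumes the rightmost batch dim of weight/bias is 1 (sampled outside the
    data plate), so it broadcasts against the data dim N.
    """
    # Treat each data row as a column vector; matmul broadcasts the weight's
    # batch dims against x's data dim. Shape: (..., N, out_features).
    out = (weight @ x.unsqueeze(-1)).squeeze(-1)
    return out + bias


torch.manual_seed(0)
x = torch.randn(5, 3)

# Un-batched case: same result as F.linear.
w, b = torch.randn(1, 3), torch.randn(1)
plain = broadcast_linear(x, w, b)

# Vectorized case with the shapes from the trace: 1000 particles + 7 plate dims.
batch = (1000,) + (1,) * 7
wv, bv = torch.randn(batch + (1, 3)), torch.randn(batch + (1,))
mean = broadcast_linear(x, wv, bv).squeeze(-1)
print(mean.shape)  # torch.Size([1000, 1, 1, 1, 1, 1, 1, 5])
```

Inside `BayesianRegression.forward`, one could try replacing `self.linear(x)` with `broadcast_linear(x, self.linear.weight, self.linear.bias)`; reading the `PyroSample` attributes there should draw the `linear.weight` and `linear.bias` sites as before, but this substitution is untested here.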

Thanks @martinjankowiak.
I plan to introduce BNNs with Pyro to a bunch of people at work and this might be a little complicated to communicate.
Just out of curiosity, what is the importance class used for?

As the name suggests, Importance is used for various kinds of importance sampling; see, e.g., CSIS.
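For context on the underlying idea: importance sampling draws from a proposal q (in Pyro, the guide), weights each draw by p/q, and normalizes the weights. PSIS, as used by psis_diagnostic, starts from the same log weights, log p(x, z) − log q(z), and fits a generalized Pareto distribution to the largest of them; k_hat is the fitted shape parameter. The toy sketch below uses plain torch.distributions with a made-up target N(1, 1) and proposal N(0, 2); it is not Pyro's Importance class.

```python
import torch
from torch.distributions import Normal

torch.manual_seed(0)
target = Normal(1.0, 1.0)    # p(z): toy "posterior", assumed for illustration
proposal = Normal(0.0, 2.0)  # q(z): toy "guide", wider than the target

z = proposal.sample((100_000,))
log_w = target.log_prob(z) - proposal.log_prob(z)  # log importance weights
w = torch.softmax(log_w, dim=0)                    # self-normalized weights

estimate = (w * z).sum()       # approximates E_p[z], which is 1 here
ess = 1.0 / (w ** 2).sum()     # effective sample size of the weights
print(float(estimate), float(ess))
```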

For BNNs in Pyro we generally recommend TyXe, which is built on top of Pyro. Using Pyro directly for BNNs isn’t likely to work very well unless the user has sufficient technical expertise.