Using a pyro.infer.Predictive() as a prior for a NUTS model


I’m trying to use a BNN as a prior. I was able to train the BNN and create a pyro.infer.Predictive object from it, and I can sample from that object to make predictions.

However, when I try to use it as a prior for the NUTS sampler, it doesn’t work. I’m a bit confused about how to do the sampling, since my BNN needs the x-values as input.
I tried to do it like this:

omega = pyro.sample("omega", prior(torch.linspace(domain[0], domain[1], n)))

where prior is the Predictive object. Calling it on its own with some x-values works.

However, when I use it as a prior in MCMC, it raises this error:

RuntimeError: Multiple sample sites named 'layer1.weight'
Trace Shapes:
 Param Sites:
Sample Sites:
               Trace Shapes:          
                Param Sites:          
               Sample Sites:          
_num_predictive_samples dist     |    
                       value  50 |    
          layer1.weight dist     | 5 1
                       value     | 5 1
            layer1.bias dist     | 5  
                       value     | 5  
          layer2.weight dist     | 1 5
                       value     | 1 5
            layer2.bias dist     | 1  
                       value     | 1  
                  sigma dist     |    
                       value     |    
                   data dist     |    
                       value 200 |    
                    obs dist 200 |    
                       value 200 |    

The BNN is copied from the Pyro examples and looks like this:

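import pyro
import pyro.distributions as dist
import torch.nn as nn
from pyro.nn import PyroModule, PyroSample

class BNN(PyroModule):
    def __init__(self, in_dim=1, out_dim=1, hid_dim=5, prior_scale=10.):
        super().__init__()  # must run before any submodule or PyroSample is assigned
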
        self.activation = nn.Tanh()  # or nn.ReLU()
        self.layer1 = PyroModule[nn.Linear](in_dim, hid_dim)  # Input to hidden layer
        self.layer2 = PyroModule[nn.Linear](hid_dim, out_dim)  # Hidden to output layer

        # Set layer parameters as random variables
        self.layer1.weight = PyroSample(dist.Normal(0., prior_scale).expand([hid_dim, in_dim]).to_event(2))
        self.layer1.bias = PyroSample(dist.Normal(0., prior_scale).expand([hid_dim]).to_event(1))
        self.layer2.weight = PyroSample(dist.Normal(0., prior_scale).expand([out_dim, hid_dim]).to_event(2))
        self.layer2.bias = PyroSample(dist.Normal(0., prior_scale).expand([out_dim]).to_event(1))

    def forward(self, x, y=None):
        x = x.reshape(-1, 1)
        x = self.activation(self.layer1(x))
        mu = self.layer2(x).squeeze()
        sigma = pyro.sample("sigma", dist.Gamma(.5, 1))  # Infer the response noise

        # Sampling model
        with pyro.plate("data", x.shape[0]):
            obs = pyro.sample("obs", dist.Normal(mu, sigma * sigma), obs=y)
        return mu

Thanks in advance!

NUTS requires a density that can be computed pointwise and differentiated. It’s not really clear what you’re trying to do.

I realized I had misunderstood the Pyro workflow, so the question didn’t make sense.