Hi!

I’m trying to use a BNN as a prior. I was able to train the BNN and create a `pyro.infer.Predictive` object from it, and I can sample from that object to make predictions.

However, when I try to use it as a prior for the NUTS sampler, it doesn’t work. I’m a bit confused about how to do the sampling, since my BNN needs the x-values as an input.

I tried to do it like this:

```
omega = pyro.sample("omega", prior(torch.linspace(domain[0], domain[1], n)))
```

where `prior()` is the Predictive object. This works when I call it on its own with some x-values.

However, when I try to use it as a prior in MCMC, I get this error:

```
RuntimeError: Multiple sample sites named 'layer1.weight'
          Trace Shapes:
           Param Sites:
          Sample Sites:
          Trace Shapes:
           Param Sites:
          Sample Sites:
_num_predictive_samples dist      |
                        value  50 |
          layer1.weight dist      | 5 1
                        value     | 5 1
            layer1.bias dist      | 5
                        value     | 5
          layer2.weight dist      | 1 5
                        value     | 1 5
            layer2.bias dist      | 1
                        value     | 1
                  sigma dist      |
                        value     |
                   data dist      |
                        value 200 |
                    obs dist  200 |
                        value 200 |
```

The BNN is copied from the Pyro examples and looks like this:

```
import torch.nn as nn
import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroSample


class BNN(PyroModule):
    def __init__(self, in_dim=1, out_dim=1, hid_dim=5, prior_scale=10.):
        super().__init__()
        self.activation = nn.Tanh()  # or nn.ReLU()
        self.layer1 = PyroModule[nn.Linear](in_dim, hid_dim)  # Input to hidden layer
        self.layer2 = PyroModule[nn.Linear](hid_dim, out_dim)  # Hidden to output layer

        # Set layer parameters as random variables
        self.layer1.weight = PyroSample(dist.Normal(0., prior_scale).expand([hid_dim, in_dim]).to_event(2))
        self.layer1.bias = PyroSample(dist.Normal(0., prior_scale).expand([hid_dim]).to_event(1))
        self.layer2.weight = PyroSample(dist.Normal(0., prior_scale).expand([out_dim, hid_dim]).to_event(2))
        self.layer2.bias = PyroSample(dist.Normal(0., prior_scale).expand([out_dim]).to_event(1))

    def forward(self, x, y=None):
        x = x.reshape(-1, 1)
        x = self.activation(self.layer1(x))
        mu = self.layer2(x).squeeze()
        sigma = pyro.sample("sigma", dist.Gamma(.5, 1))  # Infer the response noise

        # Sampling model
        with pyro.plate("data", x.shape[0]):
            obs = pyro.sample("obs", dist.Normal(mu, sigma * sigma), obs=y)
        return mu
```

Thanks in advance!