Hi!
I’m using a BNN as a function prior and would like to plot realizations from it. From the literature I understand that it should produce smooth functions when all the parameters are Gaussian, and this should hold before any training. Mine, however, produces static noise instead, as shown in the figure:
Any idea what could be the reason for this? I have three layers, initialized like this:
self.layer1.weight = PyroSample(dist.Normal(0.,
    torch.tensor(0.01 * float(np.sqrt(1 / n_out)))).expand([n_out, n_out]).to_event(2))
self.layer1.bias = PyroSample(dist.Normal(0.,
    torch.tensor(0.01)).expand([n_out]).to_event(1))
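For context, here is a minimal, self-contained version of one layer set up this way (OneLayer is just for illustration; note that nn.Linear(n_in, n_out) stores its weight as [n_out, n_in], so expanding to [n_out, n_out] assumes n_in == n_out):

import numpy as np
import torch
import torch.nn as nn
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroSample

class OneLayer(PyroModule):
    def __init__(self, n_out):
        super().__init__()
        self.layer1 = PyroModule[nn.Linear](n_out, n_out)
        self.layer1.weight = PyroSample(dist.Normal(0.,
            torch.tensor(0.01 * float(np.sqrt(1 / n_out)))).expand([n_out, n_out]).to_event(2))
        self.layer1.bias = PyroSample(dist.Normal(0.,
            torch.tensor(0.01)).expand([n_out]).to_event(1))

    def forward(self, t):
        # Every call re-samples layer1.weight and layer1.bias from the
        # prior, so repeated calls give independent realizations.
        return torch.tanh(self.layer1(t))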
Any help would be greatly appreciated!
Update: adding the full code of the BNN below. I’m generating realizations before any training by calling the forward() function and passing it the vector t.
The realizations should look like the image attached (from this paper: https://arxiv.org/abs/2112.10663).
Could it be that I can’t get the values of the weights and biases when they are wrapped in PyroSample? If so, how could I work around this?
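One way to read the sampled values back out, if I understand Pyro’s tracing correctly, is to record a single execution of the model with poutine.trace and inspect its sample sites:

from pyro import poutine

# bnn, t and A as in the usage sketch further below
tr = poutine.trace(bnn).get_trace(t, A)
for name, site in tr.nodes.items():
    if site["type"] == "sample":
        # e.g. "layers.0.weight", "layers.0.bias", ..., "sigma", "obs"
        print(name, tuple(site["value"].shape))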
import numpy as np
import torch
import torch.nn as nn
import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroSample

class BNN(PyroModule):
    def __init__(self, n_in, n_out, layers):
        super().__init__()
        self.n_layers = len(layers)
        # Note: the weight priors below are expanded to [n_out, n_out], while
        # nn.Linear(n_in, n_out) stores its weight as [n_out, n_in], so this
        # assumes n_in == n_out.
        self.layers = PyroModule[torch.nn.ModuleList]([
            PyroModule[nn.Linear](n_in, n_out)
            for j in range(self.n_layers)
        ])
        self.activations = []
        for ii, layer in enumerate(layers):
            # Scale the prior scale of the last layer's weights:
            # n^(-1/2) for Gaussian layers and n^(-1) for Cauchy layers.
            if ii == self.n_layers - 1:
                if layers[layer]['type'] == 'gaussian':
                    weight = layers[layer]['weight'] * float(1 / np.sqrt(n_out))
                else:
                    weight = layers[layer]['weight'] * float(1 / n_out)
            else:
                weight = layers[layer]['weight']
            bias = layers[layer]['bias']
            if layers[layer]['type'] == 'cauchy':
                self.layers[ii].weight = PyroSample(dist.Cauchy(0.,
                    torch.tensor(weight)).expand([n_out, n_out]).to_event(2))
                self.layers[ii].bias = PyroSample(dist.Cauchy(0.,
                    torch.tensor(bias)).expand([n_out]).to_event(1))
            elif layers[layer]['type'] == 'gaussian':
                self.layers[ii].weight = PyroSample(dist.Normal(0.,
                    torch.tensor(weight)).expand([n_out, n_out]).to_event(2))
                self.layers[ii].bias = PyroSample(dist.Normal(0.,
                    torch.tensor(bias)).expand([n_out]).to_event(1))
            else:
                print('Invalid layer!')
            self.activations.append(layers[layer]['activation'])
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()

    def forward(self, t, A, y=None):
        # t = t.reshape(-1, 1)
        # First layer acts on the input t.
        if self.activations[0] == 'tanh':
            mu = self.tanh(self.layers[0](t))
        elif self.activations[0] == 'relu':
            mu = self.relu(self.layers[0](t))
        else:
            mu = self.layers[0](t)
        # Hidden layers.
        for ii in range(1, self.n_layers - 1):
            if self.activations[ii] == 'tanh':
                mu = self.tanh(self.layers[ii](mu))
            elif self.activations[ii] == 'relu':
                mu = self.relu(self.layers[ii](mu))
            else:
                mu = self.layers[ii](mu)
        # Output layer.
        if self.activations[-1] == 'tanh':
            mu = self.tanh(self.layers[-1](mu))
        elif self.activations[-1] == 'relu':
            mu = self.relu(self.layers[-1](mu))
        else:
            mu = self.layers[-1](mu)
        y_hat = torch.matmul(A, mu)
        sigma = pyro.sample("sigma", dist.Uniform(0., torch.tensor(0.01)))
        # n_y came from the enclosing scope in my original code; I assume
        # here that it equals the length of y_hat.
        n_y = y_hat.shape[0]
        with pyro.plate("data", n_y):
            obs = pyro.sample("obs", dist.Normal(y_hat, sigma), obs=y)
        return mu
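For completeness, this is roughly how I draw the realizations; the layers dict and shapes here are illustrative (since the weights are expanded to [n_out, n_out], I use n_in == n_out and a 1-D t of length n):

import matplotlib.pyplot as plt

n = 100
layers = {  # illustrative three-layer configuration
    'l1': {'type': 'gaussian', 'weight': 1., 'bias': 1., 'activation': 'tanh'},
    'l2': {'type': 'gaussian', 'weight': 1., 'bias': 1., 'activation': 'tanh'},
    'l3': {'type': 'gaussian', 'weight': 1., 'bias': 1., 'activation': 'none'},
}
bnn = BNN(n_in=n, n_out=n, layers=layers)

t = torch.linspace(0., 1., n)  # 1-D input, so mu comes out with shape [n]
A = torch.eye(n)
for _ in range(5):  # each call re-samples the weights from the prior
    mu = bnn(t, A)
    plt.plot(t.numpy(), mu.detach().numpy())
plt.show()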