Hello! I am working with a Bayesian convolutional neural network to model the properties of sequences. From task A, I am able to obtain samples from the posterior of the model shown at the bottom of the post (which specifies a factorized Gaussian prior over all of the weights in the network). Now, for task B, I want to use the estimates of the posterior parameter means and covariances from task A to initialize the prior of a model with the same architecture. This is conceptually similar to what is shown here with Bayesian updating, the only difference being that I also want to specify the covariances between parameters, using a multivariate Gaussian over the full set of parameters. What would be the easiest way to accomplish this?
import torch
import torch.nn as nn
import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroSample


class bayesian_cnn(PyroModule):
    def __init__(self, scale=1):
        super().__init__()
        # first convolutional layer
        self.conv1 = PyroModule[nn.Conv1d](in_channels=1,
                                           out_channels=2,
                                           kernel_size=10,
                                           stride=1,
                                           padding='same')
        # factorized Normal(0, scale) prior over every weight and bias
        self.conv1.weight = PyroSample(dist.Normal(0., scale).expand([2, 1, 10]).to_event(3))
        self.conv1.bias = PyroSample(dist.Normal(0., scale).expand([2]).to_event(1))
        self.flat = nn.Flatten()
        # in_features = 40 corresponds to 2 channels x input sequences of length 20
        self.dense1 = PyroModule[nn.Linear](in_features=40,
                                            out_features=20)
        self.dense1.weight = PyroSample(dist.Normal(0., scale).expand([20, 40]).to_event(2))
        self.dense1.bias = PyroSample(dist.Normal(0., scale).expand([20]).to_event(1))
        self.dense2 = PyroModule[nn.Linear](in_features=20,
                                            out_features=20)
        self.dense2.weight = PyroSample(dist.Normal(0., scale).expand([20, 20]).to_event(2))
        self.dense2.bias = PyroSample(dist.Normal(0., scale).expand([20]).to_event(1))
        self.dense_final = PyroModule[nn.Linear](in_features=20,
                                                 out_features=1)
        self.dense_final.weight = PyroSample(dist.Normal(0., scale).expand([1, 20]).to_event(2))
        self.dense_final.bias = PyroSample(dist.Normal(0., scale).expand([1]).to_event(1))

    def forward(self, x, y=None):
        x = torch.relu(self.conv1(x))
        x = self.flat(x)
        x = torch.relu(self.dense1(x))
        x = torch.relu(self.dense2(x))
        x = self.dense_final(x)
        mu = x.squeeze()
        sigma = pyro.sample("sigma", dist.Uniform(0., 1))
        with pyro.plate("data", x.shape[0]):
            # note: the observation scale passed here is sigma squared
            obs = pyro.sample("obs", dist.Normal(mu, sigma * sigma), obs=y)
        return mu
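
To make the question more concrete, here is a rough sketch of the first half of what I have in mind: moment-matching a single multivariate Gaussian to the task A posterior samples. It is only an illustration under several assumptions of mine. I assume the samples come as a dict of tensors with a leading sample dimension, keyed by names such as `conv1.weight`. `PARAM_SHAPES`, `flatten_samples`, `unflatten`, `gaussian_from_samples`, and the `jitter` value are hypothetical names and choices, not part of Pyro. The random draws at the end are stand-ins for real posterior samples, and `sigma` is left out.

```python
import math
import torch

# Shapes of the weight/bias sites of the network above, in a fixed order.
PARAM_SHAPES = {
    "conv1.weight": (2, 1, 10),
    "conv1.bias": (2,),
    "dense1.weight": (20, 40),
    "dense1.bias": (20,),
    "dense2.weight": (20, 20),
    "dense2.bias": (20,),
    "dense_final.weight": (1, 20),
    "dense_final.bias": (1,),
}


def flatten_samples(samples):
    """Stack per-site samples of shape (S, *shape) into one (S, D) matrix."""
    return torch.cat(
        [samples[name].reshape(samples[name].shape[0], -1) for name in PARAM_SHAPES],
        dim=1,
    )


def unflatten(theta):
    """Split a flat (D,) vector back into a dict of named, shaped tensors."""
    out, offset = {}, 0
    for name, shape in PARAM_SHAPES.items():
        n = math.prod(shape)
        out[name] = theta[offset:offset + n].reshape(shape)
        offset += n
    return out


def gaussian_from_samples(samples, jitter=1e-3):
    """Moment-match a full-covariance Gaussian to the flattened samples."""
    flat = flatten_samples(samples)           # (S, D)
    loc = flat.mean(dim=0)                    # (D,)
    cov = torch.cov(flat.T)                   # torch.cov expects variables in rows
    cov = 0.5 * (cov + cov.T)                 # enforce exact symmetry
    # Assumption: a small diagonal jitter keeps the matrix positive definite
    # when there are fewer samples than parameters.
    cov = cov + jitter * torch.eye(cov.shape[0], dtype=flat.dtype)
    return torch.distributions.MultivariateNormal(loc, covariance_matrix=cov)


# Stand-in for real task A posterior samples (200 draws per site).
torch.manual_seed(0)
fake_samples = {
    name: torch.randn(200, *shape, dtype=torch.float64)
    for name, shape in PARAM_SHAPES.items()
}
prior = gaussian_from_samples(fake_samples)
params = unflatten(prior.sample())
print(prior.loc.shape, params["dense1.weight"].shape)
```

What I am unsure about is the second half. One option I can imagine is drawing a single flat vector in the task B model with `pyro.sample("theta", dist.MultivariateNormal(loc, covariance_matrix=cov))` and passing the unpacked tensors to `torch.nn.functional.conv1d` and `torch.nn.functional.linear`, instead of using per-layer `PyroSample` priors. I also realize that a moment-matched Gaussian only approximates the task A posterior.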