Specify the covariance of a multivariate Gaussian prior for a Bayesian neural network

Hello! I am working with a Bayesian convolutional neural network to model the properties of sequences. In task A, I can obtain posterior samples from the model shown at the bottom of the post, which places a factorized Gaussian prior over all of the weights in the network. For task B, I want to use the posterior means and covariances estimated in task A to initialize the prior of a model with the same architecture. This is conceptually similar to the Bayesian updating shown here. The only difference is that I also want to specify the covariance between parameters, i.e. use a multivariate Gaussian over the full set of parameters. What would be the easiest way to accomplish this?
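A first step could be to turn the task-A posterior samples into one mean vector and one covariance matrix over all parameters: flatten every sample site into a single row per posterior draw, then take the empirical mean and covariance. This is a minimal sketch, assuming `samples` is a dict mapping site name (e.g. `"conv1.weight"`) to a tensor with a leading draw dimension, as returned by Pyro's `MCMC.get_samples()`. The function name `posterior_mean_and_cov` and the default `jitter` value are hypothetical choices. With fewer draws than parameters, the sample covariance has rank at most `num_draws - 1`, so a small diagonal jitter is added to keep it positive definite.

```python
import torch


def posterior_mean_and_cov(samples, site_names, jitter=1e-6):
    """Empirical mean and covariance of the joint (flattened) parameter vector.

    `samples` is assumed to map site name -> tensor of shape [num_draws, *param_shape].
    The order of `site_names` fixes the order of entries in the flat vector.
    """
    # One row per posterior draw, one column per scalar parameter.
    flat = torch.cat(
        [samples[name].reshape(samples[name].shape[0], -1) for name in site_names],
        dim=1,
    )
    mean = flat.mean(dim=0)
    centered = flat - mean
    # Unbiased sample covariance, shape [D, D].
    cov = centered.T @ centered / (flat.shape[0] - 1)
    # Diagonal jitter keeps the matrix positive definite when it is rank deficient.
    cov = cov + jitter * torch.eye(cov.shape[0], dtype=cov.dtype)
    return mean, cov
```

The same `site_names` order must be reused wherever the flat vector is split back into layer weights.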

import torch
import torch.nn as nn
import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroSample


class bayesian_cnn(PyroModule):
    def __init__(self, scale=1.):
        super().__init__()
        # First convolutional layer: 1 input channel -> 2 output channels.
        # padding='same' keeps the sequence length unchanged.
        self.conv1 = PyroModule[nn.Conv1d](in_channels=1,
                                           out_channels=2,
                                           kernel_size=10,
                                           stride=1,
                                           padding='same')
        # Factorized prior: every weight and bias is an independent Normal(0, scale).
        self.conv1.weight = PyroSample(dist.Normal(0., scale).expand([2, 1, 10]).to_event(3))
        self.conv1.bias = PyroSample(dist.Normal(0., scale).expand([2]).to_event(1))

        # [batch, 2, L] -> [batch, 2 * L]; dense1 expects 40 features, so L = 20.
        self.flat = nn.Flatten()

        self.dense1 = PyroModule[nn.Linear](in_features=40, out_features=20)
        self.dense1.weight = PyroSample(dist.Normal(0., scale).expand([20, 40]).to_event(2))
        self.dense1.bias = PyroSample(dist.Normal(0., scale).expand([20]).to_event(1))

        self.dense2 = PyroModule[nn.Linear](in_features=20, out_features=20)
        self.dense2.weight = PyroSample(dist.Normal(0., scale).expand([20, 20]).to_event(2))
        self.dense2.bias = PyroSample(dist.Normal(0., scale).expand([20]).to_event(1))

        self.dense_final = PyroModule[nn.Linear](in_features=20, out_features=1)
        self.dense_final.weight = PyroSample(dist.Normal(0., scale).expand([1, 20]).to_event(2))
        self.dense_final.bias = PyroSample(dist.Normal(0., scale).expand([1]).to_event(1))

    def forward(self, x, y=None):
        # x: [batch, 1, 20]
        x = torch.relu(self.conv1(x))
        x = self.flat(x)
        x = torch.relu(self.dense1(x))
        x = torch.relu(self.dense2(x))
        x = self.dense_final(x)

        # Drop only the trailing output dim so a batch of size 1 keeps shape [1].
        mu = x.squeeze(-1)
        sigma = pyro.sample("sigma", dist.Uniform(0., 1.))

        with pyro.plate("data", x.shape[0]):
            # Note: the observation noise standard deviation here is sigma**2.
            pyro.sample("obs", dist.Normal(mu, sigma * sigma), obs=y)

        return mu