BNN training succeeds with AutoDiagonalNormal() but fails with AutoNormal()

Hello,
I am working on a simple Bayesian neural network model.
I want to know the mean and variance of the weights and biases in each layer of a Bayesian neural network.
First, I trained the model successfully with AutoDiagonalNormal(), but with this guide I could not find the mean and variance of the weights in each layer, because it stores all latent variables in one flattened loc vector and one flattened scale vector.
Then I replaced AutoDiagonalNormal() with AutoNormal(), so that I could read the mean and scale of each layer's weights with pyro.get_param_store().items(). However, with AutoNormal() the training did not go well: the model did not fit the training data accurately.
According to Pyro's official documentation, AutoNormal() should be equivalent to AutoDiagonalNormal(). Can you think of any reason why training succeeded with AutoDiagonalNormal() but failed with AutoNormal()?
Alternatively, can I achieve my goal with AutoDiagonalNormal() by outputting the weight parameters in a way that shows which parameters belong to which layer?

Below is the code in question.
(This code is based on the book “pythonで始めるベイズ機械学習入門”, roughly “An Introduction to Bayesian Machine Learning Starting with Python”.)

I generated the training data using the following code.

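import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroSample
from pyro.infer import SVI, Trace_ELBO, Predictive
from pyro.infer.autoguide import AutoDiagonalNormal, AutoNormal
from tqdm import tqdm

N = 30  # number of training points

def make_data(x, eps):
    # True function 10*sin(3x)*exp(-x^2) plus Gaussian noise with std eps
    y = 10*np.sin(3*x) * np.exp(-x**2)
    noise = np.random.normal(0, eps, size=x.shape[0])
    return y + noise

x_data = np.random.uniform(low=-2, high=2, size=N)
y_data = make_data(x_data, 2.0)

# Show data points and the true function
x_linspace = np.linspace(-2., 2, 1000)
y_linspace = make_data(x_linspace, 0.0)

plt.plot(x_data, y_data, 'o', markersize=2, label='data')
plt.plot(x_linspace, y_linspace, label='true_func')
plt.legend()
plt.show()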

And here is my BNN model.

#model
h1, h2 = 10, 10

class Model(PyroModule):
    def __init__(self, h1=h1, h2=h2):
        super().__init__()
        self.fc1 = PyroModule[nn.Linear](1, h1)
        self.fc1.weight = PyroSample(dist.Normal(0.,10.).expand([h1, 1]).to_event(2))
        self.fc1.bias = PyroSample(dist.Normal(0.,10.).expand([h1]).to_event(1))

        self.fc2 = PyroModule[nn.Linear](h1, h2)
        self.fc2.weight = PyroSample(dist.Normal(0.,10.).expand([h2, h1]).to_event(2))
        self.fc2.bias = PyroSample(dist.Normal(0.,10.).expand([h2]).to_event(1))

        self.fc3 = PyroModule[nn.Linear](h2, 1)
        self.fc3.weight = PyroSample(dist.Normal(0.,10.).expand([1, h2]).to_event(2))
        self.fc3.bias = PyroSample(dist.Normal(0.,10.).expand([1]).to_event(1))
        self.relu = nn.ReLU()

    def forward(self, X, Y=None, h1=h1, h2=h2):
        # Output of the neural network (h1 and h2 are not used here)
        X = self.relu(self.fc1(X))
        X = self.relu(self.fc2(X))
        mu = self.fc3(X)

        # Observation noise
        sigma = pyro.sample("sigma", dist.Uniform(0., 2.0))

        # Data points are conditionally independent given the weights
        with pyro.plate("data", X.shape[0]):
            pyro.sample("Y", dist.Normal(mu, sigma).to_event(1), obs=Y)
        return mu

model = Model(h1=h1, h2=h2)

Here is the inference code.
For the AutoNormal run, I replaced ‘guide=AutoDiagonalNormal(model)’ with ‘guide=AutoNormal(model)’.

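pyro.clear_param_store()

# Mean-field Gaussian guide over all latent sites (the run that works);
# the failing run used guide = AutoNormal(model) instead.
guide = AutoDiagonalNormal(model)

adam = pyro.optim.Adam({'lr': 0.0005})

svi = SVI(model, guide, adam, loss=Trace_ELBO())

# Shape (N, 1) tensors to match nn.Linear(1, h1)
x_data_torch = torch.from_numpy(x_data).float().unsqueeze(-1)
y_data_torch = torch.from_numpy(y_data).float().unsqueeze(-1)

n_epoch = 100000
loss_list = []
for epoch in tqdm(range(n_epoch)):
    loss = svi.step(x_data_torch, y_data_torch, h1, h2)  # returns the ELBO loss estimate
    loss_list.append(loss)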

I made predictions with the trained model using the following code.

predictive = Predictive(model, guide=guide, num_samples=500)

x_new = torch.linspace(-2.0, 2.0, 1000).unsqueeze(-1)
with torch.no_grad():
    # 'Y' samples have shape (500, 1000, 1); drop the trailing dim for plotting
    y_pred_samples = predictive(x_new, None, h1, h2)['Y'].squeeze(-1).numpy()
y_pred_mean = y_pred_samples.mean(axis=0)
percentiles = np.percentile(y_pred_samples, [5.0, 95.0], axis=0)

x_plot = x_new.squeeze(-1).numpy()
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(x_data, y_data, 'o', markersize=3, label='data')
ax.plot(x_linspace, y_linspace, label='true_func')
ax.plot(x_plot, y_pred_mean, label='mean')
ax.fill_between(x_plot, percentiles[0, :], percentiles[1, :],
                alpha=0.5, label='90percentile', color='orange')

ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_ylim(-13, 13)
ax.legend()
plt.show()