Simple Bayesian Neural Network throws errors

Dear pyromaniacs,

I am working on quite a hairy non-linear regression problem where I would like to sample posterior predictive distributions. I know from MLE studies that the likelihood landscape is multimodal and heteroskedastic. MLE in my problem is computationally expensive and I aim to analyze many samples, so I opted to turn my problem into a regression where I simulate a lot of data and train a non-linear probabilistic regression model.

As a low-dimensional example, I have created some data below where the top plot is the dependent variable (called ratio) and the bottom plot shows data for three independent variables (fraction: X+0, X+1, X+2). I drew the posterior predictive distribution I expect in the top plot in red given the yellow points in the bottom plot as test data. The x-axis is just an indexing function that makes the plots look nice.

For reference, my real-life problem has 5 dependent variables (all domain [0, 1)) and around 70 independent variables. Subsets of the independent variables are probability mass functions, as is also true for the subset X+0, X+1, X+2 in the figure (sum to 1, domain [0,1]). As you can appreciate, the dimensions of this problem are small compared to image classification or other typical Neural Network tasks. Therefore I believe that some sort of Bayesian Neural Network (BNN) in pyro will suit my purposes just fine.

However, despite reading the tutorials on Bayesian linear regression, I struggle to construct a BNN in pyro. I found the blog post "Experimenting with Pyro's hidden native support for Bayesian Neural Networks", but its code no longer works: it fails with AttributeError: 'HiddenLayer' object has no attribute '_batch_shape'. The code also strikes me as verbose (duplication between model and guide), and since I imagine BNNs to be a common use case for pyro, I was wondering whether there is a shorter way of achieving this. I know that variational inference with independent Gaussian variational distributions over weights and biases can never produce a multi-modal posterior, so I am especially interested in the approach from the "Boosting Black Box Variational Inference" tutorial (Pyro Tutorials 1.8.4 documentation) or in some form of exact inference using NUTS or another MCMC sampler.

In the code below, DummyData is a torch Dataset. Wrapped in a DataLoader, it yields per batch the sample indices, a [BATCH_SIZE x 3] tensor of independent variables (shown in the bottom plot) and a [BATCH_SIZE x 1] tensor with the dependent variable. As a first test, I want to construct a NN with one hidden Bayesian layer, over which I would like to perform variational inference with independent Gaussian variational distributions for the weights and biases. The parameters of the other layers should simply be optimized as point estimates.

from pyro.infer import SVI, Trace_ELBO
from pyro.nn import PyroSample
from pyro.nn.module import PyroModule
from pyro.infer.autoguide import AutoDiagonalNormal
from torch import nn
import pyro
import pyro.distributions as dist
from torch.utils.data import DataLoader, Dataset
from sumoflux.estimate.pai import DummyData
from tqdm import tqdm

model = PyroModule[nn.Sequential](
    PyroModule[nn.Linear](3, 3),
    PyroModule[nn.Tanh](),
    PyroModule[nn.Linear](3, 3),
    PyroModule[nn.Tanh](),
    PyroModule[nn.Linear](3, 1),
)
DEVICE = 'cpu'
BATCH_SIZE = 64
EPOCHS = 100

# The weight is a 2-D tensor, so both of its dimensions are declared as event dimensions.
model[2].weight = PyroSample(prior=dist.Normal(0, 1).expand(model[2].weight.shape).to_event(2))
model[2].bias = PyroSample(prior=dist.Normal(0, 1).expand(model[2].bias.shape).to_event(1))

guide = AutoDiagonalNormal(model)
adam = pyro.optim.Adam({"lr": 0.03})
svi = SVI(model, guide, adam, loss=Trace_ELBO())

losses = []
train = DummyData()
dl = DataLoader(train, batch_size=BATCH_SIZE, shuffle=True)
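for _ in tqdm(range(EPOCHS)):
    for data in dl:
        # Each batch holds the sample indices, the inputs and the targets.
        idx, x, y = data
        x, y = x.to(DEVICE), y.to(DEVICE)
        # This is the call that raises the TypeError shown below.
        loss = svi.step(x, y)
        losses.append(loss)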

This throws the following error:

...
  File "C:\Miniconda3\envs\pta\lib\site-packages\pyro\nn\module.py", line 426, in __call__
    return super().__call__(*args, **kwargs)
  File "C:\Miniconda3\envs\pta\lib\site-packages\torch\nn\modules\module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
TypeError: forward() takes 2 positional arguments but 3 were given

Why is this? In the "Bayesian Regression - Introduction (Part 1)" tutorial (Pyro Tutorials 1.8.4 documentation), passing x_data and y_data to svi.step (code block 11) works as expected. Also, does svi.step handle computing and clearing the gradients automatically? Finally, how would I add dropout to this NN (possibly along the lines of arXiv:1506.02142, "Dropout as a Bayesian Approximation: Representing Model Uncertainty in Deep Learning")?