Flax with NumPyro

Hi all,

Link to the Google Colab notebook:

I am trying to use NumPyro sampling on top of a model defined with Flax.

  • First, I loaded the Boston Housing data, defined a feed-forward Flax model, FlaxMLP, and trained it using Optax. The only input argument to the Flax model is layers.

  • Next, I defined the NumPyro model as follows:

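import numpyro
import numpyro.distributions as dist
from numpyro.contrib.module import random_flax_module

def BayesianMLP(x, y=None):
    # Wrap the Flax object as a NumPyro module with a Normal(0, 1) prior on every parameter
    net = random_flax_module(name='bnn',
                             nn_module=FlaxMLP(args.layers),
                             prior=dist.Normal(0, 1),
                             # dummy input shape used to initialize the Flax parameters
                             input_shape=(1, x.shape[-1]))
    with numpyro.plate("data", x.shape[0]):
        # Define the distribution on the observation (unit noise scale)
        numpyro.sample("yhat", dist.Normal(loc=net(x)).to_event(1), obs=y)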
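  • Finally, I ran SVI inference:
    import jax
    import optax
    from numpyro.infer import SVI, Trace_ELBO, autoguide

    guide = autoguide.AutoNormal(BayesianMLP)
    # Define the stochastic variational inference optimizer using optax
    optimizer_svi = getattr(optax, args.optax_optimizer)(learning_rate=args.step_size)
    # Loss function is the ELBO
    svi = SVI(BayesianMLP, guide, optimizer_svi, loss=Trace_ELBO())
    # run() takes a PRNG key and the step count first (args.num_steps is an assumed name)
    svi_results = svi.run(jax.random.PRNGKey(0), args.num_steps,
                          x=X_train, y=Y_train)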

However, the mean of the predictions is far off, and with MCMC it seems that every chain is diverging.
Any help would be appreciated.

How many parameters does your neural network have? Naive approaches to Bayesian neural networks, especially large ones, can fail spectacularly.
We generally suggest using TyXe for Bayesian neural networks in Pyro.

Thanks, it has 116 parameters.

The example I provided is a simplified version of the real problem I am working on; I put together a minimal working script with similar behavior to ask the community.
I chose Flax and NumPyro to take advantage of the consistent representation of parameters and arrays across JAX, Optax, Flax, and NumPyro.

It’s hard to say what’s going on, but you might try lowering the learning rate and using shallower neural networks, and see whether you can get those to work before making things harder.