Hi all,
Link to the Google Colab notebook:
I am trying to use NumPyro sampling on top of a model defined in Flax.

First, I loaded the Boston Housing data, defined a feedforward Flax model, `FlaxMLP`, and trained it using `optax`. The only input argument to the Flax model is `layers`.
Next, I defined the `numpyro` model as follows:
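```python
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.module import random_flax_module

def BayesianMLP(x, y=None):
    # Wrap the Flax module as a NumPyro random module: every weight and bias
    # gets an independent Normal(0, 1) prior
    net = random_flax_module(
        name="bnn",
        nn_module=FlaxMLP(args.layers),
        prior=dist.Normal(0, 1),
        input_shape=(x.shape[1],),  # pass the input shape as a tuple
    )
    with numpyro.plate("data", x.shape[0]):
        # Gaussian likelihood for the observations (scale defaults to 1.0)
        numpyro.sample("yhat", dist.Normal(loc=net(x)).to_event(1), obs=y)
```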
Finally, I ran stochastic variational inference (SVI):
```python
import optax
from numpyro.infer import SVI, Trace_ELBO, autoguide

# Mean-field Normal guide over all latent sites (the network weights)
guide = autoguide.AutoNormal(BayesianMLP)
# Build the optax optimizer by name, e.g. args.optax_optimizer = "adam"
optimizer_svi = getattr(optax, args.optax_optimizer)(learning_rate=args.step_size)
# The loss function is the negative ELBO
svi = SVI(BayesianMLP, guide, optimizer_svi, loss=Trace_ELBO())
svi_results = svi.run(
    rng_key=rng_key,
    num_steps=args.num_samp_svi,
    x=X_train,
    y=Y_train,
    progress_bar=args.progress_bar,
)
```
However, the mean of the predictions is far off, and when I use MCMC instead, every chain seems to diverge.
Any help would be appreciated.