How can I use PyroModule (or something like it) in NumPyro?

I want to run Bayesian linear regression with PyroModule in NumPyro, but it fails. Do you know how to use PyroModule, or something similar, in NumPyro?

from pyro.nn import PyroSample
from torch import nn
from pyro.nn import PyroModule

import numpyro

import jax.numpy as np
from jax import random, vmap
from jax.scipy.special import logsumexp
import numpyro
from numpyro.diagnostics import hpdi
import numpyro.distributions as dist
from numpyro import handlers
from numpyro.infer import MCMC, NUTS

# PyroModule[nn.Linear] must be a subclass of both the wrapped torch layer
# and PyroModule itself; check each relationship in turn.
assert all(
    issubclass(PyroModule[nn.Linear], base) for base in (nn.Linear, PyroModule)
)


class BayesianRegression:
    """Bayesian linear regression written as a plain NumPyro model.

    ``PyroModule``/``PyroSample`` cannot be used with NumPyro: they wrap a
    stateful ``torch.nn`` module and sample through Pyro (torch), not NumPyro
    (JAX), so NumPyro's ``NUTS`` cannot trace them. In NumPyro a model is just
    a Python callable. The linear layer's weight and bias are therefore
    sampled explicitly with ``numpyro.sample``, and the layer
    ``x @ weight.T + bias`` (the formula ``torch.nn.Linear`` computes) is
    written out by hand.

    Args:
        in_features: number of input features (columns of ``x``).
        out_features: number of outputs of the linear layer. ``forward``
            squeezes the last axis of the mean, so this is expected to be 1.
    """

    def __init__(self, in_features, out_features):
        self.in_features = in_features
        self.out_features = out_features

    def forward(self, x, y=None):
        """Sample the regression parameters and optionally condition on ``y``.

        Args:
            x: design matrix of shape ``(N, in_features)``.
            y: observed targets of shape ``(N,)``, or ``None`` to sample them.

        Returns:
            The regression mean, shape ``(N,)``.
        """
        sigma = numpyro.sample("sigma", dist.Uniform(0., 10.))
        # Same priors as the original PyroSample definitions; site names follow
        # PyroModule's attribute-path naming ("linear.weight", "linear.bias").
        weight = numpyro.sample(
            "linear.weight",
            dist.Normal(0., 1.).expand([self.out_features, self.in_features]).to_event(2),
        )
        bias = numpyro.sample(
            "linear.bias",
            dist.Normal(0., 10.).expand([self.out_features]).to_event(1),
        )
        mean = (x @ weight.T + bias).squeeze(-1)
        with numpyro.plate("data", x.shape[0]):
            numpyro.sample("obs", dist.Normal(mean, sigma), obs=y)
        return mean

    # NumPyro calls the model object directly (e.g. via ``NUTS(model)``);
    # mirror nn.Module, where calling the instance dispatches to ``forward``.
    __call__ = forward

model = BayesianRegression(3, 1)

nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)
rng_key = random.PRNGKey(0)
# trainX / trainY are not defined in this snippet; supply your own data.
# The model above expects trainX of shape (N, 3) and trainY of shape (N,).
mcmc.run(rng_key, trainX, trainY)

# MCMC has no ``summary()`` method. ``print_summary()`` prints the posterior
# table itself and returns None, so it is not wrapped in print().
mcmc.print_summary()

[From my reply to https://github.com/pyro-ppl/numpyro/issues/558 for future reference]

Because JAX’s neural-network libraries have a functional interface, I think this might be easier to do in NumPyro without resorting to `random_module`. For instance, if you build your NN with `stax` (`jax.experimental.stax`, now `jax.example_libraries.stax`), you can pass parameter values directly to the NN’s `apply_fn`. To make your NN Bayesian, sample these values from an appropriate prior. Here’s some demo code:

class ModelNN:
    """Bayesian single-unit dense layer (linear regression) built with stax.

    stax layers are pure functions: ``init_fn`` only creates parameters, and
    ``apply_fn(params, x)`` takes them as an explicit argument. The model can
    therefore sample ``W`` and ``b`` from priors with ``numpyro.sample`` and
    pass them straight to ``apply_fn``. No ``PyroModule``/``random_module``
    is needed.

    Args:
        rng_key: JAX PRNG key used to initialise the stax parameters, whose
            shapes fix the shapes of the sampled ``W`` and ``b``.
        in_features: number of input features (columns of ``x``).
    """

    def __init__(self, rng_key, in_features):
        # ``stax`` moved from jax.experimental to jax.example_libraries. Import
        # it locally so the class works on both older and newer JAX releases.
        try:
            from jax.example_libraries import stax
        except ImportError:
            from jax.experimental import stax

        init_fn, apply_fn = stax.Dense(1)
        _, init_params = init_fn(rng_key, (in_features,))
        # Within this class only the shapes of these (W, b) arrays are used;
        # the values fed to apply_fn are prior samples drawn in __call__.
        self.init_params = init_params
        self.apply_fn = apply_fn

    def __call__(self, x, y=None):
        """Sample W, b, sigma and the observations.

        Args:
            x: inputs of shape ``(N, in_features)``.
            y: observed targets of shape ``(N,)``, or ``None`` to sample them.

        Returns:
            Tuple ``(y, W, b, sigma)``: the observed (or sampled) targets and
            the parameter values used for this trace.
        """
        W = numpyro.sample('W', dist.Normal(0., 1.), sample_shape=self.init_params[0].shape)
        b = numpyro.sample('b', dist.Normal(0., 10.), sample_shape=self.init_params[1].shape)
        mean = self.apply_fn((W, b), x).squeeze(-1)
        sigma = numpyro.sample("sigma", dist.HalfNormal(1.))
        with numpyro.plate("M", x.shape[0]):
            y = numpyro.sample("obs", dist.Normal(mean, sigma), obs=y)
        return y, W, b, sigma


model = ModelNN(random.PRNGKey(0), 5)

# Generate synthetic data: 1000 inputs with 5 features. Running the model
# forward under a fixed seed then draws "true" W, b, sigma and targets y.
x = dist.Normal(np.array([1., -1., 2., 3., 1.]), 1.).sample(random.PRNGKey(1), sample_shape=(1000,))
y, W, b, sigma = handlers.seed(model, random.PRNGKey(2))(x)
print(W, b, sigma)

# Run inference: condition on y and recover posteriors for W, b and sigma.
# num_warmup/num_samples are passed by keyword (keyword-only in current NumPyro).
nuts = NUTS(model)
mcmc = MCMC(nuts, num_warmup=500, num_samples=500)
mcmc.run(random.PRNGKey(4), x, y)
mcmc.print_summary()

Thank you very much! I will try it.