I want to run Bayesian linear regression using PyroModule in NumPyro, but it fails. Is there a way to use PyroModule, or something equivalent to the code below, in NumPyro?
from pyro.nn import PyroSample
from torch import nn
from pyro.nn import PyroModule
import numpyro
import jax.numpy as np
from jax import random, vmap
from jax.scipy.special import logsumexp
import numpyro
from numpyro.diagnostics import hpdi
import numpyro.distributions as dist
from numpyro import handlers
from numpyro.infer import MCMC, NUTS
# Sanity check: the class produced by PyroModule[nn.Linear] must be a subclass
# of both the wrapped torch layer type and PyroModule itself.
assert all(
    issubclass(PyroModule[nn.Linear], expected_base)
    for expected_base in (nn.Linear, PyroModule)
)
class BayesianRegression:
    """Bayesian linear regression written as a native NumPyro model.

    The model is a plain callable instead of a Pyro ``PyroModule``:
    ``PyroModule``/``PyroSample`` wrap ``torch.nn`` modules, whereas NumPyro
    inference (e.g. ``NUTS``) expects a Python callable that declares its
    random variables with ``numpyro.sample``. The weight and bias of the
    linear layer are therefore sampled explicitly, under the site names
    ``"linear.weight"`` and ``"linear.bias"``.

    Args:
        in_features: Number of input columns in ``x``.
        out_features: Number of outputs of the linear map.
    """

    def __init__(self, in_features, out_features):
        self.in_features = in_features
        self.out_features = out_features

    def forward(self, x, y=None):
        """Declare the sample sites and return the predicted mean.

        Args:
            x: Design matrix; its last axis must have ``in_features`` entries
                and its first axis is used as the size of the ``"data"`` plate.
            y: Observed targets, or ``None`` to sample ``"obs"`` from the
                model instead of conditioning on data.

        Returns:
            The linear predictor ``x @ weight.T + bias``, with a trailing
            axis of size 1 removed.
        """
        sigma = numpyro.sample("sigma", dist.Uniform(0., 10.))
        weight = numpyro.sample(
            "linear.weight",
            dist.Normal(0., 1.)
            .expand([self.out_features, self.in_features])
            .to_event(2),
        )
        bias = numpyro.sample(
            "linear.bias",
            dist.Normal(0., 10.).expand([self.out_features]).to_event(1),
        )
        mean = np.matmul(x, weight.T) + bias
        # Only drop the last axis when it has size 1, so a multi-output model
        # does not raise (torch's ``squeeze(-1)`` is a no-op in that case).
        if mean.shape[-1] == 1:
            mean = mean.squeeze(-1)
        with numpyro.plate("data", x.shape[0]):
            numpyro.sample("obs", dist.Normal(mean, sigma), obs=y)
        return mean

    # NumPyro kernels call the model object directly, so expose ``forward``
    # through ``__call__``.
    __call__ = forward
# Build the regression model with 3 inputs and 1 output.
# NOTE(review): assumes trainX has 3 feature columns; trainX and trainY are
# not defined in this snippet, so confirm their shapes.
model = BayesianRegression(3, 1)

# Sample the posterior with NUTS: 500 warm-up steps, then 1000 retained draws.
nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)
rng_key = random.PRNGKey(0)
mcmc.run(rng_key, trainX, trainY)

# NumPyro's MCMC has no `summary()` method (that is Pyro's API);
# `print_summary()` prints the table itself and returns None.
mcmc.print_summary()