I’m very new to NumPyro and have been trying to implement a simple linear regression with it, just to get an understanding of how it works. Here’s a snippet of the code I’ve written based on the documentation:
from jax import random
# jax.numpy is imported as np, NOT numpy
import jax.numpy as np
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
# Choose the "true" parameters.
m_true = -0.9594
b_true = 4.294
f_true = 1.0
# initializing parameters
N = 50; J = 1
X = random.normal(random.PRNGKey(seed=123), (N, J))
weight = np.array([m_true])
error = 0.1 * random.normal(random.PRNGKey(234), (N, ))
y_obs = f_true * (X @ weight + b_true) + error
def model(X, y=None):
    ndims = np.shape(X)[-1]
    ws = numpyro.sample('betas', dist.Normal(0.0, 10.0 * np.ones(ndims)))
    b = numpyro.sample('b', dist.Normal(0.0, 10.0))
    sigma = numpyro.sample('sigma', dist.Uniform(0.0, 10.0))
    f = numpyro.sample('f', dist.Normal(0.0, 10.0))
    mu = f * (X @ ws + b)
    return numpyro.sample("y", dist.Normal(mu, sigma), obs=y)
nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1500, num_chains=1)
mcmc.run(random.PRNGKey(4), X, y=y_obs)
mcmc.print_summary()
The model is of the form y = f*(m*x + b). My question is: how is the likelihood function defined? The `model` function, I presume, defines probabilistic variables for all of the model parameters. Does the return statement in the model correspond to the posterior probability distribution? And how does NumPyro know what the standard errors in my data are — how do I specify them?