Hi, I’m implementing the linear model described in Chapter 20 of the book Doing Bayesian Data Analysis (2nd edition) by John K. Kruschke, using NumPyro. The same model is implemented using PyMC3 here.
My current NumPyro implementation is very slow: it takes 10 minutes to draw 20,000 samples. Is there anything wrong with my model, and how can I make it faster?
import jax.numpy as jnp
import jax.random as random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
import pandas as pd
def gammaDistFromModeStd(mode, std):
    """Return a Gamma distribution with the given mode and standard deviation.

    The (shape, rate) pair is obtained by solving ``mode = (shape - 1) / rate``
    and ``std**2 = shape / rate**2`` for the positive root of ``rate``.

    Parameters
    ----------
    mode
        Desired mode of the distribution.
    std
        Desired standard deviation of the distribution.

    Returns
    -------
    dist.Gamma
        Gamma distribution built from the derived shape and rate.
    """
    variance = std ** 2
    # Square-root term of the quadratic formula for the rate.
    root_term = jnp.sqrt(mode ** 2 + 4 * variance)
    rate = (mode + root_term) / (2 * variance)
    return dist.Gamma(1 + mode * rate, rate)
def multi_nominal_predictors(y: jnp.ndarray, grp: jnp.ndarray, nb_groups: 'list[int]'):
    """
    Linear model for a metric outcome with two nominal predictors and their
    interaction, following Chapter 20, Section 20.2, Figure 20.2.

    Each family of coefficients (first factor, second factor, interaction)
    is drawn from a zero-centred Normal whose scale is itself sampled.

    Parameters
    ----------
    y: jnp.ndarray
        Metric predicted variable; must be one-dimensional.
    grp: jnp.ndarray
        Nominal predictors; must be two-dimensional with one row per entry
        of `y` and exactly two columns. Its values are used as indices into
        the coefficient arrays.
    nb_groups: list[int]
        Number of unique groups in each column of `grp` (two entries).
    """
    # Shape checks on the inputs.
    assert y.shape[0] == grp.shape[0]
    assert y.ndim == 1 and grp.ndim == 2
    assert grp.shape[1] == len(nb_groups) == 2

    n_rows = y.shape[0]

    # Location and scale of the data; every prior below is expressed
    # relative to these two numbers.
    center = jnp.mean(y)
    spread = jnp.std(y)

    def sigma_prior():
        # Prior shared by the three coefficient scales:
        # mode at spread / 2, standard deviation 2 * spread.
        return gammaDistFromModeStd(spread / 2, spread * 2)

    # Intercept.
    a0 = numpyro.sample('a0', dist.Normal(center, spread * 5))

    # Deflections for the levels of the first factor.
    a1_sigma = numpyro.sample('a1_sigma', sigma_prior())
    a1_prior = dist.Normal(0, a1_sigma).expand((nb_groups[0],))
    a1 = numpyro.sample('a1', a1_prior)

    # Deflections for the levels of the second factor.
    a2_sigma = numpyro.sample('a2_sigma', sigma_prior())
    a2_prior = dist.Normal(0, a2_sigma).expand((nb_groups[1],))
    a2 = numpyro.sample('a2', a2_prior)

    # Deflections for every (first factor, second factor) combination.
    a12_sigma = numpyro.sample('a12_sigma', sigma_prior())
    a12_prior = dist.Normal(0, a12_sigma).expand(tuple(nb_groups))
    a12 = numpyro.sample('a12', a12_prior)

    # Observation noise.
    noise_prior = dist.Uniform(spread / 100, spread * 10)
    y_sigma = numpyro.sample('y_sigma', noise_prior)

    # Likelihood: one Normal per observation, centred on the sum of the
    # intercept and the deflections selected by that row's group codes.
    with numpyro.plate('obs', n_rows) as rows:
        first = grp[rows, 0]
        second = grp[rows, 1]
        main_effects = a0 + a1[first] + a2[second]
        mu = main_effects + a12[first, second]
        numpyro.sample('y', dist.Normal(mu, y_sigma), obs=y[rows])
# Load the data and turn both factor columns into pandas categoricals so
# that integer codes and the number of levels can be read off them.
salary_df = pd.read_csv('Salary.csv')
salary_df['Org'] = salary_df['Org'].astype('category')
salary_df['Pos'] = salary_df['Pos'].astype('category')

# Metric outcome.
salary = jnp.array(salary_df['Salary'].values)

# One column of integer codes per factor, joined side by side into an
# (observations, 2) array.
code_columns = []
for column in ['Org', 'Pos']:
    codes = salary_df[column].cat.codes.values
    code_columns.append(codes[..., None])
group_codes = jnp.concatenate(code_columns, axis=1)

# Number of levels of each factor, in the same column order as above.
level_counts = (
    salary_df['Org'].cat.categories.size,
    salary_df['Pos'].cat.categories.size,
)

# Single NUTS chain: 1000 warm-up iterations, then 20000 retained draws.
kernel = NUTS(multi_nominal_predictors)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=20000, num_chains=1)
mcmc.run(random.PRNGKey(0), y=salary, grp=group_codes, nb_groups=level_counts)
mcmc.print_summary()
And here is the link to the data: Salary.csv
Thank you.