# Slow sampling with NumPyro in a linear regression model with nominal predictors

Hi, I’m implementing the linear model described in Chapter 20 of the book Doing Bayesian Data Analysis (2nd edition) by John K. Kruschke, using NumPyro. The same model is implemented using PyMC3 here.

My current NumPyro implementation is very slow: it takes 10 minutes to draw 20,000 samples. Is there anything wrong with my model, and how can I make it faster?

```python
import jax.numpy as jnp
import jax.random as random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
import pandas as pd


def gammaDistFromModeStd(mode, std):
    """Gamma distribution parameterized by its mode and standard deviation."""
    std_squared = std**2
    rate = (mode + jnp.sqrt(mode**2 + 4 * std_squared)) / (2 * std_squared)
    shape = 1 + mode * rate

    return dist.Gamma(shape, rate)


def multi_nominal_predictors(y: jnp.ndarray, grp: jnp.ndarray, nb_groups: 'list[int]'):
    """
    Bayesian model as described in Chapter 20, Section 20.2, Figure 20.2

    Parameters
    ----------
    y: jnp.ndarray
        Metric predicted variable.
    grp: jnp.ndarray
        Nominal predictors, one integer-coded column per factor.
    nb_groups: list[int]
        List of the number of unique groups in each column of `grp`.
    """
    assert y.shape[0] == grp.shape[0]
    assert y.ndim == 1 and grp.ndim == 2
    assert grp.shape[1] == len(nb_groups) == 2

    nb_obs = y.shape[0]

    # Predicted statistics.
    y_mean = jnp.mean(y)
    y_sd = jnp.std(y)

    # Priors for the intercept.
    a0 = numpyro.sample('a0', dist.Normal(y_mean, y_sd * 5))

    # Priors for coefficients associated with the first factor.
    a1_sigma = numpyro.sample(
        'a1_sigma', gammaDistFromModeStd(y_sd / 2, y_sd * 2))
    a1 = numpyro.sample(
        'a1', dist.Normal(0, a1_sigma).expand((nb_groups[0], )))

    # Priors for coefficients associated with the second factor.
    a2_sigma = numpyro.sample(
        'a2_sigma', gammaDistFromModeStd(y_sd / 2, y_sd * 2))
    a2 = numpyro.sample(
        'a2', dist.Normal(0, a2_sigma).expand((nb_groups[1], )))

    # Priors for coefficients associated with the interaction
    # between the first and the second factor.
    a12_sigma = numpyro.sample(
        'a12_sigma', gammaDistFromModeStd(y_sd / 2, y_sd * 2))
    a12 = numpyro.sample(
        'a12', dist.Normal(0, a12_sigma).expand(tuple(nb_groups)))

    # Priors for y_sigma.
    y_sigma = numpyro.sample('y_sigma', dist.Uniform(y_sd / 100, y_sd * 10))

    # Observations.
    with numpyro.plate('obs', nb_obs) as idx:
        g1 = grp[idx, 0]
        g2 = grp[idx, 1]
        mean = a0 + a1[g1] + a2[g2] + a12[g1, g2]
        numpyro.sample('y', dist.Normal(mean, y_sigma), obs=y[idx])


# Salary.csv from the link below, assumed to be in the working directory.
salary_df = pd.read_csv('Salary.csv')
salary_df['Org'] = salary_df['Org'].astype('category')
salary_df['Pos'] = salary_df['Pos'].astype('category')

kernel = NUTS(multi_nominal_predictors)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=20000, num_chains=1)
mcmc.run(
    random.PRNGKey(0),
    y=jnp.array(salary_df['Salary'].values, dtype=jnp.float32),
    grp=jnp.concatenate([salary_df[c].cat.codes.values[..., None] for c in ['Org', 'Pos']], axis=1),
    nb_groups=(salary_df['Org'].cat.categories.size, salary_df['Pos'].cat.categories.size),
)
mcmc.print_summary()
```
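`gammaDistFromModeStd` solves for the Gamma shape and rate from a requested mode and standard deviation, using mode = (shape - 1) / rate and sd = sqrt(shape) / rate (valid when shape > 1). A small sketch to confirm the algebra round-trips; `gamma_shape_rate_from_mode_std` is a hypothetical helper that returns the raw parameters instead of a distribution.

```python
import jax.numpy as jnp


def gamma_shape_rate_from_mode_std(mode, std):
    # Hypothetical helper: same algebra as gammaDistFromModeStd, but it returns
    # (shape, rate) so the result can be checked directly.
    rate = (mode + jnp.sqrt(mode**2 + 4 * std**2)) / (2 * std**2)
    shape = 1 + mode * rate
    return shape, rate


shape, rate = gamma_shape_rate_from_mode_std(2.0, 3.0)
# Recover mode and standard deviation from the Gamma(shape, rate) formulas.
mode_back = (shape - 1) / rate
std_back = jnp.sqrt(shape) / rate
print(mode_back, std_back)
```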

And here is the link to the data: Salary.csv

Thank you.

It's hard to say without taking a close look at your model, but this tutorial may be helpful; e.g., you might reparameterize `a2` and `a12`.

So I reparametrized all the `a`s as follows:

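```python
# Non-centered form, shown for a1; a2 and a12 follow the same pattern with
# their own shapes, (nb_groups[1],) and tuple(nb_groups).
a1_sigma = numpyro.sample(
    'a1_sigma', gammaDistFromModeStd(y_sd / 2, y_sd * 2))
a1_ = numpyro.sample('a1_', dist.Normal(0, 1).expand((nb_groups[0], )))
a1 = numpyro.deterministic('a1', a1_ * a1_sigma)
```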

Sampling does speed up: it now takes only around 2 minutes. But I’m curious: why does reparametrization improve the sampling speed?
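A sketch of the usual explanation, looking only at the prior term for one coefficient: in the centered form `a ~ Normal(0, sigma)`, the curvature of the log density along `a` is -1/sigma², so the natural step length for `a` changes as the sampler moves through values of `sigma`. NUTS adapts a single step size and a fixed mass matrix during warmup, so it cannot track a scale that depends on `sigma`: steps small enough for small `sigma` waste many leapfrog steps elsewhere, and larger steps diverge. In the non-centered form the sampled variable is always standard normal. This ignores the likelihood; when each group is strongly informed by data, the centered form can sample better.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def centered_logp(a, sigma):
    # Centered: a ~ Normal(0, sigma), so sigma sets the length scale of a.
    return norm.logpdf(a, loc=0.0, scale=sigma)


def noncentered_logp(a_raw, sigma):
    # Non-centered: a_raw ~ Normal(0, 1) and a = a_raw * sigma is deterministic,
    # so sigma drops out of the density along the sampled coordinate.
    return norm.logpdf(a_raw, loc=0.0, scale=1.0)


# Second derivative with respect to the first argument, evaluated at 0,
# for a batch of sigma values.
centered_curv = jax.vmap(lambda s: jax.grad(jax.grad(centered_logp))(0.0, s))
noncentered_curv = jax.vmap(lambda s: jax.grad(jax.grad(noncentered_logp))(0.0, s))

sigmas = jnp.array([0.01, 1.0, 100.0])
print(centered_curv(sigmas))     # close to -1 / sigma**2: spans eight orders of magnitude
print(noncentered_curv(sigmas))  # -1 for every sigma
```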