Slow sampling with Numpyro in linear regression model with nominal predictors

Hi, I’m implementing the linear model described in chapter 20 in the book Doing Bayesian Data Analysis (2nd edition) by John K. Kruschke using Numpyro. The same model is implemented using PyMC3 here.

My current implementation using Numpyro is very slow, it takes 10 minutes to finish sampling 20,000 samples. Is there anything wrong with my model? And how can I make it faster?

import jax.numpy as jnp
import jax.random as random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
import pandas as pd


def gammaDistFromModeStd(mode, std):
    std_squared = std**2
    rate = (mode + jnp.sqrt(mode**2 + 4 * std_squared)) / (2 * std_squared)
    shape = 1 + mode * rate

    return dist.Gamma(shape, rate)


def multi_nominal_predictors(y: jnp.ndarray, grp: jnp.ndarray, nb_groups: 'list[int]'):
    """
    Bayesian model as described in Chapter 20, Section 20.2, Figure 20.2

    Parameters
    ----------
    y: jnp.ndarray
        Metric predicted variable.
    grp: jnp.ndarray
        Nominal predictors.
    nb_groups: list[int]
        List of the number of unique groups in each column of `grp`.
    """
    assert y.shape[0] == grp.shape[0]
    assert y.ndim == 1 and grp.ndim == 2
    assert grp.shape[1] == len(nb_groups) == 2

    nb_obs = y.shape[0]

    # Predicted statistics.
    y_mean = jnp.mean(y)
    y_sd = jnp.std(y)

    # Priors for the intercept.
    a0 = numpyro.sample('a0', dist.Normal(y_mean, y_sd * 5))

    # Priors for coefficients associated with the first factor.
    a1_sigma = numpyro.sample(
        'a1_sigma', gammaDistFromModeStd(y_sd / 2, y_sd * 2))
    a1 = numpyro.sample(
        'a1', dist.Normal(0, a1_sigma).expand((nb_groups[0], )))

    # Priors for coefficients associated with the second factor.
    a2_sigma = numpyro.sample(
        'a2_sigma', gammaDistFromModeStd(y_sd / 2, y_sd * 2))
    a2 = numpyro.sample(
        'a2', dist.Normal(0, a2_sigma).expand((nb_groups[1], )))

    # Priors for coefficients associated with the interaction
    # between the first and the second factor.
    a12_sigma = numpyro.sample(
        'a12_sigma', gammaDistFromModeStd(y_sd / 2, y_sd * 2))
    a12 = numpyro.sample(
        'a12', dist.Normal(0, a12_sigma).expand(tuple(nb_groups)))

    # Priors for y_sigma.
    y_sigma = numpyro.sample('y_sigma', dist.Uniform(y_sd / 100, y_sd * 10))

    # Observations.
    with numpyro.plate('obs', nb_obs) as idx:
        g1 = grp[idx, 0]
        g2 = grp[idx, 1]
        mean = a0 + a1[g1] + a2[g2] + a12[g1, g2]
        numpyro.sample('y', dist.Normal(mean, y_sigma), obs=y[idx])


salary_df = pd.read_csv('Salary.csv')
salary_df['Org'] = salary_df['Org'].astype('category')
salary_df['Pos'] = salary_df['Pos'].astype('category')

kernel = NUTS(multi_nominal_predictors)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=20000, num_chains=1)
mcmc.run(
    random.PRNGKey(0),
    y=jnp.array(salary_df['Salary'].values),
    grp=jnp.concatenate([salary_df[c].cat.codes.values[..., None] for c in ['Org', 'Pos']], axis=1),
    nb_groups=(salary_df['Org'].cat.categories.size, salary_df['Pos'].cat.categories.size),
)
mcmc.print_summary()

And here is the link to the data: Salary.csv

Thank you.

hard to say without taking a close look at your model but this tutorial may be helpful; e.g. you might reparameterize a2 and a12

So I reparametrize all the a’s as followed:

a_sigma = numpyro.sample(
        'a_sigma', gammaDistFromModeStd(y_sd / 2, y_sd * 2))
a_ = numpyro.sample('a_', dist.Normal(0, 1))
a = numpyro.deterministic('a', a_ * a_sigma)

The sampling process does speed up, it now only takes around 2 minutes. But I’m curious why does reparametrization improve the sampling speed?