Converting Statistical Rethinking Model 16.5 to NumPyro

I am trying to implement the following model in NumPyro (model 16.5 from the book Statistical Rethinking, page 547):

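$$
\begin{aligned}
h_t &\sim \operatorname{LogNormal}(\log(p_H H_t), \sigma_h) \\
l_t &\sim \operatorname{LogNormal}(\log(p_L L_t), \sigma_l) \\
H_1 &\sim \operatorname{LogNormal}(\log 10, 1) \\
L_1 &\sim \operatorname{LogNormal}(\log 10, 1) \\
H_T &= H_1 + \int_1^T H_t (b_H - m_H L_t)\,dt \\
L_T &= L_1 + \int_1^T L_t (b_L H_t - m_L)\,dt \\
\sigma_h, \sigma_l &\sim \operatorname{Exponential}(1) \\
p_H, p_L &\sim \operatorname{Beta}(\alpha, \beta) \\
b_H, m_L &\sim \operatorname{HalfNormal}(1, 0.5) \\
b_L, m_H &\sim \operatorname{HalfNormal}(0.05, 0.05)
\end{aligned}
$$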

Below is my attempt to translate it verbatim:

import jax.numpy as jnp
import jax as jx
from jax.experimental.ode import odeint

import numpyro as npr
import numpyro.distributions as dist

def model(N, pelts=None):
    sigma_H = npr.sample("sigma_h", dist.Exponential(1.0))
    sigma_L = npr.sample("sigma_l", dist.Exponential(1.0))
    pH = npr.sample("pH", dist.Beta(40,200))
    pL = npr.sample("pL", dist.Beta(40,200))
    bH = npr.sample("bH", dist.HalfNormal(1, 0.5))
    bL = npr.sample("bL", dist.HalfNormal(0.05, 0.05))
    mH = npr.sample("mH", dist.HalfNormal(0.05, 0.05))
    mL = npr.sample("mL", dist.HalfNormal(1, 0.5))
    H1 = npr.sample("H1", dist.LogNormal(jnp.log(10), 1))
    L1 = npr.sample("L1", dist.LogNormal(jnp.log(10), 1))
    times_measured = jnp.arange(float(N))
    pop = npr.deterministic("pop", odeint(dpop_dt, 
                                          jnp.array([L1, H1]), 
                                          times_measured, 
                                          jnp.array([bH, mH, mL, bL]),
                                          rtol=1e-5, atol=1e-5, mxstep=500))
    if pelts is None:
        lt = npr.sample("lt", dist.LogNormal(jnp.log(pL*pop[:,0]),sigma_L))
        ht = npr.sample("ht", dist.LogNormal(jnp.log(pH*pop[:,1]),sigma_H))
    else:
        lt = npr.sample("lt", dist.LogNormal(jnp.log(pL*pop[:,0]),sigma_L), obs=pelts[:,0])
        ht = npr.sample("ht", dist.LogNormal(jnp.log(pH*pop[:,1]),sigma_H), obs=pelts[:,1])
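The model above calls `dpop_dt`, which the post does not show. A minimal sketch of that right-hand side, reconstructed from the equations (hypothetical, not the poster's actual function), assuming the state is ordered `[L, H]` to match `jnp.array([L1, H1])` and the parameters `[bH, mH, mL, bL]` to match the `odeint` call:

```python
import jax.numpy as jnp
from jax.experimental.ode import odeint


def dpop_dt(pop, t, theta):
    """Lotka-Volterra right-hand side (hypothetical reconstruction).

    pop:   state vector [L, H] (lynx, hares), same order as jnp.array([L1, H1])
    t:     time; unused because the system is autonomous, but odeint passes it
    theta: parameters [bH, mH, mL, bL], same order as in the odeint call
    """
    L, H = pop[0], pop[1]
    bH, mH, mL, bL = theta[0], theta[1], theta[2], theta[3]
    dH_dt = H * (bH - mH * L)  # hares: births minus predation by lynx
    dL_dt = L * (bL * H - mL)  # lynx: growth from eating hares minus mortality
    return jnp.stack([dL_dt, dH_dt])  # must follow the [L, H] state order


# Integrate from illustrative initial values over 5 time steps.
traj = odeint(dpop_dt, jnp.array([10.0, 10.0]), jnp.arange(5.0),
              jnp.array([1.0, 0.05, 1.0, 0.05]), rtol=1e-5, atol=1e-5, mxstep=500)
print(traj.shape)  # (5, 2): one row per time point, columns [L, H]
```

If the returned derivatives were stacked in a different order than the initial state, `pop[:,0]` and `pop[:,1]` would silently swap meaning, so keeping the three orders (state, derivative, `pelts` columns) aligned matters.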
        

However, when I sample it with

m16_5 = npr.infer.MCMC(npr.infer.NUTS(model, target_accept_prob=0.95), num_chains=3, num_samples=1000, num_warmup=1000)
m16_5.run(jx.random.PRNGKey(0), **dat_list)

I run into the following problems:

  1. Chain 1 takes about 5 minutes to draw 1000 samples, while chains 2 and 3 finish in less than a minute on my laptop.
  2. The effective sample size (n_eff) is NaN for several parameters (see below).
  3. The posterior predictions are all wrong (either too small or too large).

Here is the posterior summary:

                mean       std    median      5.0%     95.0%     n_eff     r_hat
        H1    127.72     92.99    168.79      1.35    223.45      1.59      3.95
        L1     23.79     15.73     30.29      2.58     40.09       nan      3.62
        bH      1.66      1.58      0.58      0.48      3.90      1.50     33.23
        bL      0.35      0.49      0.00      0.00      1.05      1.50    731.17
        mH      0.07      0.09      0.01      0.00      0.19      1.50    118.64
        mL      0.63      0.26      0.75      0.28      0.90       nan      3.74
        pH      0.19      0.02      0.20      0.14      0.21       nan      1.10
        pL      0.31      0.19      0.19      0.15      0.57      1.51     10.20
   sigma_h      0.52      0.39      0.27      0.21      1.07      1.51     11.60
   sigma_l      0.50      0.35      0.28      0.21      0.99      1.51     10.40

Number of divergences: 0
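An r_hat far above 1 (for example 731.17 for bL) means the three chains disagree about the posterior. One way to see how they disagree is to compare each chain's posterior mean. A minimal sketch, assuming the fitted `m16_5` object from above; `per_chain_means` is a hypothetical helper, while `MCMC.get_samples(group_by_chain=True)` is NumPyro's own method:

```python
import numpy as np


def per_chain_means(samples_by_chain):
    """Posterior mean of each site, computed separately for every chain.

    samples_by_chain: dict mapping site name to an array shaped
    (num_chains, num_samples, ...), e.g. the output of
    mcmc.get_samples(group_by_chain=True).
    """
    return {name: np.asarray(draws).mean(axis=1)
            for name, draws in samples_by_chain.items()}


# Usage with the fitted model from the question (not run here):
# means = per_chain_means(m16_5.get_samples(group_by_chain=True))
# for name, m in means.items():
#     print(name, m)  # one value per chain; large gaps mean the chains disagree
```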

What are the mistakes in my attempt?

Here is a working implementation by more experienced people: Chapter 16. Generalized Linear Madness | Statistical Rethinking (2nd ed.) with NumPyro

However, it is significantly modified, with no explanation of why.

The two models seem quite similar; the corresponding variable names are:

  • (bH, bL, mH, mL): theta
  • (pH, pL): p
  • (sigma_H, sigma_L): sigma
  • (H1, L1): pop_init
  • (lt, ht): pelts

One difference I can see is that you use a HalfNormal distribution for theta, while the other model uses a TruncatedNormal distribution. HalfNormal has only a scale parameter (it is always centred at zero), so dist.HalfNormal(1, 0.5) cannot express the book's prior, which is a normal with mean 1 and standard deviation 0.5 truncated at zero. Probably you wanted a TruncatedNormal distribution?