I am trying to implement the following model in NumPyro (from the book Statistical Rethinking, page 547, model 16.5). Below is my attempt to translate it verbatim:
```python
import jax.numpy as jnp
import jax as jx
from jax.experimental.ode import odeint
import numpyro as npr
import numpyro.distributions as dist

def model(N, pelts=None):
    sigma_H = npr.sample("sigma_h", dist.Exponential(1.0))
    sigma_L = npr.sample("sigma_l", dist.Exponential(1.0))
    pH = npr.sample("pH", dist.Beta(40, 200))
    pL = npr.sample("pL", dist.Beta(40, 200))
    bH = npr.sample("bH", dist.HalfNormal(1, 0.5))
    bL = npr.sample("bL", dist.HalfNormal(0.05, 0.05))
    mH = npr.sample("mH", dist.HalfNormal(0.05, 0.05))
    mL = npr.sample("mL", dist.HalfNormal(1, 0.5))
    H1 = npr.sample("H1", dist.LogNormal(jnp.log(10), 1))
    L1 = npr.sample("L1", dist.LogNormal(jnp.log(10), 1))
    times_measured = jnp.arange(float(N))
    pop = npr.deterministic("pop", odeint(dpop_dt,
                                          jnp.array([L1, H1]),
                                          times_measured,
                                          jnp.array([bH, mH, mL, bL]),
                                          rtol=1e-5, atol=1e-5, mxstep=500))
    if pelts is None:
        lt = npr.sample("lt", dist.LogNormal(jnp.log(pL * pop[:, 0]), sigma_L))
        ht = npr.sample("ht", dist.LogNormal(jnp.log(pH * pop[:, 1]), sigma_H))
    else:
        lt = npr.sample("lt", dist.LogNormal(jnp.log(pL * pop[:, 0]), sigma_L),
                        obs=pelts[:, 0])
        ht = npr.sample("ht", dist.LogNormal(jnp.log(pH * pop[:, 1]), sigma_H),
                        obs=pelts[:, 1])
```
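(For completeness, `dpop_dt` is the Lotka-Volterra system from the book; it isn't shown above. A sketch of what it looks like, with the state ordered `[L, H]` to match the initial condition `jnp.array([L1, H1])` and the parameters ordered to match `jnp.array([bH, mH, mL, bL])` at the call site:)

```python
import jax.numpy as jnp

def dpop_dt(pop, t, theta):
    """Lotka-Volterra derivative. State is [lynx, hare] to match the
    initial condition jnp.array([L1, H1]) in the model above."""
    L, H = pop
    bH, mH, mL, bL = theta        # same order as jnp.array([bH, mH, mL, bL])
    dH = H * (bH - mH * L)        # hare births minus predation deaths
    dL = L * (bL * H - mL)        # lynx births from predation minus deaths
    return jnp.stack([dL, dH])
```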
However, when I sample from it with

```python
m16_5 = npr.infer.MCMC(npr.infer.NUTS(model, target_accept_prob=0.95),
                       num_chains=3, num_samples=1000, num_warmup=1000)
m16_5.run(jx.random.PRNGKey(0), **dat_list)
```
I run into the following problems:
- Chain 1 takes about 5 minutes to draw 1000 samples, while chains 2 and 3 finish in less than a minute on my laptop.
- The effective sample size (n_eff) is NaN for several parameters (see the summary below).
- The posterior predictive draws are all wrong (some far too small, others far too large).
Here is the posterior summary:

```
           mean    std  median   5.0%  95.0%  n_eff   r_hat
H1       127.72  92.99  168.79   1.35 223.45   1.59    3.95
L1        23.79  15.73   30.29   2.58  40.09    nan    3.62
bH         1.66   1.58    0.58   0.48   3.90   1.50   33.23
bL         0.35   0.49    0.00   0.00   1.05   1.50  731.17
mH         0.07   0.09    0.01   0.00   0.19   1.50  118.64
mL         0.63   0.26    0.75   0.28   0.90    nan    3.74
pH         0.19   0.02    0.20   0.14   0.21    nan    1.10
pL         0.31   0.19    0.19   0.15   0.57   1.51   10.20
sigma_h    0.52   0.39    0.27   0.21   1.07   1.51   11.60
sigma_l    0.50   0.35    0.28   0.21   0.99   1.51   10.40

Number of divergences: 0
```
What are the mistakes in my attempt?
For reference, here is a successful attempt by more accomplished people: Chapter 16. Generalized Linear Madness | Statistical Rethinking (2nd ed.) with NumPyro. But that version is significantly modified from the book's model, without an explanation of why.