MCMC sampling breaks down for Poisson regression and high means

Hi, I am trying to fit a Poisson regression model with MCMC sampling.
My problem is that the sampling breaks down for high values of the Poisson rate. Take the minimal reproducer below: if I set the mean of my (dummy) data to 400, everything is fine and I can fit the model (top plots below). However, if I increase the mean to, e.g., 1000, the sampling breaks down completely (bottom plots below). Any idea what I am doing wrong?

import numpyro as nmp
from numpyro.infer import MCMC, NUTS, Predictive
import numpyro.distributions as dist
from jax import random
import jax.numpy as jnp
from numpy.random import default_rng
import arviz as az

# Create dummy data
rng = default_rng()
mean = 400  # works; raising this to e.g. 1000 reproduces the failure
samples = rng.poisson(mean, size=1000)

# Define model
def model(data):
    # Normal prior on the log-rate; exponentiate to get the Poisson rate
    mu = nmp.sample("mu", dist.Normal(8.0, 2.0))
    lam = jnp.exp(mu)
    nmp.sample("y", dist.Poisson(lam), obs=data)

# Fit model
nuts_kernel = NUTS(model, target_accept_prob=0.99)
mcmc = MCMC(nuts_kernel, num_samples=4000, num_warmup=2000, num_chains=1)
rng_key = random.PRNGKey(0)
mcmc.run(rng_key, samples)

az.plot_trace(mcmc)

Have you tried using 64-bit precision? Try calling numpyro.enable_x64() at the top of your script, before any JAX arrays are created.


Thanks a lot, that solved the issue.