`handlers.reparam` leads to an `az.from_numpyro` error

Hi there, I just started using numpyro to build a hierarchical model in which a group-level parameter is sampled from a circular (von Mises) distribution. I use `handlers.reparam` to apply `CircularReparam()` to this parameter.

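```python
import numpy
import numpyro
import numpyro.distributions as dist
from numpyro import handlers
from numpyro.infer.reparam import CircularReparam


def hierarchical_model(sub_idx, xx, yy=None):
    # Define hyperpriors mu_m, sigma_m for m
    mu_m = numpyro.sample('mu_m', dist.Uniform(-numpy.pi, numpy.pi))
    # Note: the second argument of dist.VonMises is a concentration, not a scale
    sigma_m = numpyro.sample('sigma_m', dist.HalfNormal(50.))

    # Sample m, one value per subject (n_sub is assumed to be defined elsewhere)
    with numpyro.plate("plate_i", n_sub):
        with handlers.reparam(config={'m': CircularReparam()}):
            m = numpyro.sample('m', dist.VonMises(mu_m, sigma_m))

    # Sample estimates
    # .....
```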

The model runs without error, but I cannot load the results with `az.from_numpyro()` because it raises a NotImplementedError in ~/numpyro/distributions/distribution.py, line 249. The error goes away when I remove the reparam config for `m`.
Do you have any suggestions for solving this?

> cannot read the data with az.from_numpyro() because of a NotImplementedError

I’m not sure about this error message. Could you provide reproducible code and the full error traceback? Also, try updating arviz and numpyro to see whether the issue has already been resolved.

Thanks for your reply :) I have the latest numpyro and arviz installed, with Python 3.8.

I simplified the case to estimating the parameters of a von Mises function:

```python
# TEST example
import numpy as np
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
import arviz as az
import seaborn as sns
from jax import random
from numpyro import handlers
from numpyro.infer import MCMC, NUTS
from numpyro.infer.reparam import CircularReparam
from scipy import special


def simple_model(xx, yy_obs=None):
    # Circular location parameter, reparameterized with CircularReparam
    with handlers.reparam(config={'m': CircularReparam()}):
        m = numpyro.sample('m', dist.VonMises(0, 0.5))
    # Concentration of the von Mises curve
    k = numpyro.sample('k', dist.HalfCauchy(2))

    # Sampling error (used as the von Mises concentration of the likelihood)
    err = numpyro.sample('err', dist.HalfNormal(100))

    # Define the mean of estimated values yy as a von Mises function of xx
    yy = 10 * jnp.exp(k * jnp.cos(xx - m)) / (2 * jnp.pi * jnp.i0(k))

    # Sample estimates from Von Mises
    numpyro.sample('obs', dist.VonMises(yy, err), obs=yy_obs)


# Generate fake data
xx = np.linspace(-np.pi, np.pi, num=100)
m = 0.5
k = 1.5
yy_obs = 10 * np.exp(k * np.cos(xx - m)) / (2 * np.pi * special.i0(k)) + np.random.uniform(0.01, 0.03, 100)

# Visualize data
sns.lineplot(x=xx, y=yy_obs)

# Run model
nuts_kernel = NUTS(simple_model)
mcmc = MCMC(nuts_kernel, num_warmup=1000, num_samples=1000)
rng_key = random.PRNGKey(0)

mcmc.run(rng_key, xx=xx, yy_obs=yy_obs)
mcmc.print_summary()

idata = az.from_numpyro(mcmc)
az.plot_trace(idata, compact=True)
```

Then `idata = az.from_numpyro(mcmc)` raised the NotImplementedError mentioned above.

When I remove the `with handlers.reparam(config={'m': CircularReparam()}):` block, the error goes away. Could this be because I am using the reparam handler incorrectly?

I just tested your code in Colab with the latest versions of both packages,

```
!pip install -Uq numpyro arviz
```

and it worked. Could you double-check your versions?
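Because the culprit here turned out to be an outdated arviz in one virtual environment, it helps to print the versions from inside the same interpreter or notebook kernel that runs the model, rather than from another shell. A minimal sketch using only the standard library (`importlib.metadata`, available from Python 3.8); `installed_versions` is a hypothetical helper name.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    # Hypothetical helper: map each distribution name to its installed version string
    found = {}
    for pkg in packages:
        try:
            found[pkg] = version(pkg)
        except PackageNotFoundError:
            found[pkg] = "not installed"
    return found


versions = installed_versions(["numpyro", "arviz", "jax", "jaxlib"])
for pkg, ver in versions.items():
    print(f"{pkg}: {ver}")
```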

Wow, you are right! I completely forgot to update arviz in that virtual environment… Thank you so much, you saved my day :slight_smile: