Different numbers of effective posterior samples when changing the names of sample sites

Hi devs and community,

I’m getting very different n_eff values from NUTS `mcmc.print_summary()` when I change the names of the sample sites of some random variables in my NumPyro model (the name in `numpyro.sample('some_name', dist)`), although the parameter estimates stay the same.

I wasn't able to reproduce this with sample data, and I can't share my own data.

Is this a known issue? Could you please suggest what could be happening?

This is interesting. It is probably related to the order in which JAX flattens dictionaries; maybe that order is non-deterministic. Could you open an issue on GitHub for this? It may be better to write a custom version of `ravel_pytree` that preserves the order of the random variables in a model.
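A minimal sketch of the order-preserving flattening idea, assuming the site order is supplied as an explicit list (for example, the order in which sites appear in a NumPyro trace). `ravel_in_order` is a hypothetical helper written for illustration, not part of JAX or NumPyro, and it only flattens and unflattens a dict of arrays; it does not plug into the NUTS implementation.

```python
import jax.numpy as jnp
import numpy as np


def ravel_in_order(params, order):
    """Hypothetical helper: flatten a dict of arrays into one 1-D vector
    following `order` (e.g. model site order) instead of sorted key order."""
    shapes = [jnp.shape(params[name]) for name in order]
    sizes = [int(np.prod(shape)) for shape in shapes]  # prod(()) == 1 for scalars
    flat = jnp.concatenate([jnp.ravel(params[name]) for name in order])

    def unravel(vec):
        # Slice the flat vector back into arrays, in the same explicit order.
        out, start = {}, 0
        for name, shape, size in zip(order, shapes, sizes):
            out[name] = jnp.reshape(vec[start:start + size], shape)
            start += size
        return out

    return flat, unravel


# Illustrative site names; the layout follows the given order, not the names.
params = {"scale": jnp.array([10.0, 20.0]), "loc": jnp.array(1.0)}
flat, unravel = ravel_in_order(params, ["loc", "scale"])

renamed = {"scale": jnp.array([10.0, 20.0]), "z_loc": jnp.array(1.0)}
flat_renamed, _ = ravel_in_order(renamed, ["z_loc", "scale"])
```

Because the order comes from an explicit list, renaming a site no longer moves its entries within the flat vector.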

This JAX issue is probably the reason: "jax.tree_utils do not keep dict key order" (Issue #4085, google/jax on GitHub).
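To see the key-order behavior that issue describes, here is an illustrative sketch using `jax.flatten_util.ravel_pytree` (the site names are made up). JAX flattens dict leaves in sorted key order rather than insertion order, so the layout is fixed for given names but changes when a site is renamed: renaming `loc` to `z_loc` moves that value from the front of the flat vector to the back. That this changed layout explains the different n_eff values is still the hypothesis from this thread, not a confirmed cause.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Same values, only one key renamed.
params = {"scale": jnp.array([10.0, 20.0]), "loc": jnp.array(1.0)}
renamed = {"scale": jnp.array([10.0, 20.0]), "z_loc": jnp.array(1.0)}

# Dict leaves are concatenated in sorted key order.
flat, _ = ravel_pytree(params)           # [1., 10., 20.]  ("loc" < "scale")
flat_renamed, _ = ravel_pytree(renamed)  # [10., 20., 1.]  ("scale" < "z_loc")
```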