Unexpectedly different outcomes when initializing via NUTS or MCMC

I’m finding that setting the initial parameter values through NUTS and setting them through MCMC give different results, even though I would expect them to be the same.

I’ve checked that the first sample of a run has the same parameter values whether I initialize using

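```python
from jax import random
from numpyro.infer import MCMC, NUTS, init_to_value

# Initial values supplied through the NUTS init strategy
nuts_kernel = NUTS(model,
                   dense_mass=dense_mass,
                   max_tree_depth=6,
                   init_strategy=init_to_value(values=...))
sampler = MCMC(nuts_kernel,
               num_warmup=10,
               num_samples=10,
               jit_model_args=True)
key, subkey = random.split(key)
sampler.warmup(subkey, X, collect_warmup=True)
```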

or with

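```python
from jax import random
from numpyro.infer import MCMC, NUTS

# No init strategy on the kernel; initial values go to MCMC.warmup instead
nuts_kernel = NUTS(model,
                   dense_mass=dense_mass,
                   max_tree_depth=6)
sampler = MCMC(nuts_kernel,
               num_warmup=10,
               num_samples=10,
               jit_model_args=True)
key, subkey = random.split(key)
sampler.warmup(subkey, X, collect_warmup=True, init_params=...)
```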

In the first case, I’m using the init_to_value strategy through the init_strategy argument of NUTS. In the second, I’m giving the initial parameter values directly to the MCMC instance through init_params.

The chains are exactly the same for the first few steps but then they diverge.

I’ve set the random keys to be the same, though I suppose that alone isn’t guaranteed to make the chains identical without looking carefully at the source. However, the statistics of the samples returned by the two initialization strategies are completely different (the R-hat values differ by an order of magnitude).

I did look quickly at the source, and there is no obvious reason in the initialization steps why these should lead to different results.

Would appreciate insights. Thanks.

If the first few steps are the same, it is likely that small numerical differences in the computation make the chains drift apart. Many things can affect NUTS trajectories, such as the stopping condition used when building the tree.

I went through the source more carefully, and I realized that the parameters are given to MCMC slightly differently from how they’re given to NUTS.

I had a deterministic component in the parameter dict I was passing to MCMC, but when the parameters are initialized through NUTS, this deterministic component isn’t included. Once I removed it from the dict passed to the init_params argument of MCMC.run, the runs were almost exactly the same.

As for the different R-hat values, that was a mistake on my part in how I split the samples into separate chains. As a side note, it might be useful to have the sampler be able to return the chains separately instead of concatenated onto one another.
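The post doesn't say exactly how the chains were split, but the effect of a splitting mistake on R-hat is easy to reproduce. The sketch below assumes the flat array is chain-major (all draws of chain 0, then chain 1, and so on), simulates four chains of which one has not converged, and computes the classic (non-split) Gelman-Rubin R-hat with NumPy. The correct reshape flags the bad chain; an interleaving reshape, one hypothetical way to get it wrong, hides it. `split_chains` and `gelman_rubin` are illustrative helpers, not library functions.

```python
import numpy as np


def split_chains(flat, num_chains):
    """Reshape chain-major (num_chains * num_samples, ...) to (num_chains, num_samples, ...)."""
    num_samples = flat.shape[0] // num_chains
    return flat.reshape((num_chains, num_samples) + flat.shape[1:])


def gelman_rubin(chains):
    """Classic (non-split) potential scale reduction factor for a (num_chains, num_samples) array."""
    _, n = chains.shape
    within = chains.var(axis=1, ddof=1).mean()     # W: mean within-chain variance
    between = n * chains.mean(axis=1).var(ddof=1)  # B: n times variance of chain means
    var_plus = (n - 1) / n * within + between / n
    return float(np.sqrt(var_plus / within))


rng = np.random.default_rng(0)
num_chains, num_samples = 4, 500
# Simulated draws: chains 0-2 centred at 0, chain 3 stuck around 3 (not converged).
offsets = np.array([0.0, 0.0, 0.0, 3.0])
per_chain = rng.normal(size=(num_chains, num_samples)) + offsets[:, None]
flat = per_chain.reshape(-1)  # chain-major concatenation

correct = split_chains(flat, num_chains)
wrong = flat.reshape(num_samples, num_chains).T  # interleaves draws from every chain

print("R-hat, correct split:", gelman_rubin(correct))
print("R-hat, interleaved split:", gelman_rubin(wrong))
```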

> it might be useful to have the sampler be able to return the chains separately instead of concatenated onto one another

Do you mean setting group_by_chain=True? https://num.pyro.ai/en/stable/mcmc.html#numpyro.infer.mcmc.MCMC.get_samples

Ah, yes… sorry, I missed the obvious.