Dealing with the breaking change to Predictive in NumPyro 0.14.0

I just tried the new NumPyro release 0.14.0, and my posterior predictive modelling broke because of this breaking change:

Breaking change: Predictive will try to avoid recomputing “deterministic” sites if it is provided in posterior_samples. Those deterministic sites are excluded in the previous releases.

I am currently using something like the following:

def model(x,y):
    .....
    mu = numpyro.deterministic("mu",  .... < calculate expected value>... )
    numpyro.sample("obs", dist.Normal(mu, residual_sigma), obs=y)

Then, to analyze the model predictions, I use:

mcmc = MCMC(NUTS(model), ...)
mcmc.run(rng_key1, real_x, real_y)

# Grid of new inputs at which to evaluate the expected value
simulated_x = jnp.linspace(-1.0, 1.0, 100)
predictive = Predictive(mcmc.sampler.model, mcmc.get_samples(), return_sites=["mu"])
mu = predictive(rng_key2, simulated_x, None)["mu"]

Here real_x and real_y are my real observations, and simulated_x is just a grid of values at which to show the mean expected value.
Prior to release 0.14.0 this code worked as expected, but with the new breaking change it just returns the mu that was recorded during the inference phase, i.e. evaluated at real_x instead of simulated_x (too many values for my purposes).

So the question is: how do I revert to the old behavior of Predictive?

See "numpyro.deterministic static on infer.Predictive", Issue #1772 in pyro-ppl/numpyro on GitHub.

OK, so samples.pop("mu") until a future release?
