Parallelising Numpyro

Moving the jitted predict_fn definition outside the loop solved the memory problem!

Thank you so much for all your help (and your very quick answers :smiley: )

Just leaving a working example for others who want to iterate over datasets without memory leaks:

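from jax import jit, random
from numpyro.infer import MCMC, NUTS, Predictive
from tqdm import tqdm

# `model`, `get_data`, `num_warmup` and `num_samples` are the user's own;
# get_data() returns the model kwargs with and without the observations.

# Build the MCMC object once, outside the loop. With jit_model_args=True,
# re-running on data of the same shape does not trigger a recompile.
mcmc = MCMC(
    NUTS(model),
    progress_bar=False,
    jit_model_args=True,
    num_warmup=num_warmup,
    num_samples=num_samples,
    num_chains=1,
    chain_method="sequential")

# Define (and jit) the prediction function once, outside the loop,
# so it is compiled a single time instead of on every iteration.
@jit
def predict_fn_jit(rng_key, samples, *args, **kwargs):
    return Predictive(model, samples)(rng_key, *args, **kwargs)

rng_key = random.PRNGKey(0)
N_runs = 100
for i in tqdm(range(N_runs)):
    # Separate keys for sampling and prediction, so they are not correlated
    rng_key, rng_key_mcmc, rng_key_pred = random.split(rng_key, 3)
    data, data_no_y = get_data()
    mcmc.run(rng_key_mcmc, **data)
    posterior_samples = mcmc.get_samples()
    predictions_jit = predict_fn_jit(rng_key_pred, posterior_samples, **data_no_y)["obs"]
    # Optional: carry the last state into the next run to skip warm-up.
    # This uses private attributes; recent NumPyro versions expose
    # `mcmc.post_warmup_state = mcmc.last_state` for the same purpose.
    # mcmc._warmup_state = mcmc._last_state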
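For anyone wondering why moving the jitted function out of the loop matters: a plausible explanation is that `jax.jit` caches compiled code per function object, so re-creating a jitted function on every iteration forces a fresh trace and compile each time, and the compiled executables pile up. Below is a minimal toy sketch in plain JAX that counts traces; `make_predict` and `count_traces` are hypothetical stand-ins for the `predict_fn_jit` above, not NumPyro code. The counter relies on the fact that Python side effects inside a jitted function run only while JAX traces it.

```python
import jax
import jax.numpy as jnp

TRACES = []  # one entry is appended each time JAX traces `predict`


def make_predict():
    """Return a freshly jitted toy 'predict' function (hypothetical stand-in)."""
    @jax.jit
    def predict(x):
        TRACES.append(1)  # Python side effect: runs at trace time only
        return 2.0 * jnp.sin(x)
    return predict


def count_traces(n_runs, define_inside_loop):
    """Call a jitted function n_runs times and report how often it was traced."""
    TRACES.clear()
    predict = make_predict()  # jit is lazy: nothing is traced until first call
    for i in range(n_runs):
        if define_inside_loop:
            # New function object each iteration, so JAX's cache never hits
            predict = make_predict()
        # Same shape and dtype every iteration, so a cached trace can be reused
        predict(jnp.arange(3.0) + i).block_until_ready()
    return len(TRACES)


print(count_traces(5, define_inside_loop=True))
print(count_traces(5, define_inside_loop=False))
```

With the function defined inside the loop you should see one trace per iteration; defined once outside, it is traced a single time and every later call reuses the cached compiled version.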