Moving the `predict_fn` definition outside the loop solved the problem!
Thank you so much for all your help (and your very quick answers).
Here is a working example for anyone who wants to iterate over datasets without leaking memory:
# Construct the NUTS sampler a single time, before the loop, so every
# iteration reuses the same MCMC object. With jit_model_args=True the model
# data are passed as dynamic arguments to the compiled sampler; per the NumPyro
# docs this lets later mcmc.run calls reuse it as long as the data shapes stay
# the same (NOTE(review): confirm get_data() returns fixed shapes).
mcmc = MCMC(
    NUTS(model),
    num_warmup=num_warmup,
    num_samples=num_samples,
    num_chains=1,
    chain_method="sequential",
    progress_bar=False,
    jit_model_args=True,
)
@jit
def predict_fn_jit(rng_key, samples, *args, **kwargs):
    """Draw posterior-predictive samples from ``model`` under ``jax.jit``.

    This is defined once at module level, outside the data loop. Presumably a
    ``@jit`` function re-created inside the loop is a new function object every
    iteration, so JAX keeps compiling and caching fresh executables. Per the
    post above, moving the definition out of the loop fixed the memory growth.

    Args:
        rng_key: JAX PRNG key used for the predictive draws.
        samples: posterior samples, e.g. the output of ``mcmc.get_samples()``.
        *args, **kwargs: model arguments forwarded to ``Predictive``.

    Returns:
        The mapping returned by ``Predictive(...)(...)``. The caller indexes it
        by sample-site name, e.g. ``["obs"]``.
    """
    return Predictive(model, samples)(rng_key, *args, **kwargs)
# Fit the model to N_runs freshly generated datasets and draw posterior
# predictions for each one. The sampler (mcmc) and the jitted predictor
# (predict_fn_jit) are both created once, above this loop.
# NOTE(review): rng_key must be initialised before this loop,
# e.g. rng_key = random.PRNGKey(0).
N_runs = 100
for i in tqdm(range(N_runs)):
    # Use independent keys for MCMC and for prediction. Reusing one key for
    # both would correlate the predictive noise with the sampler's randomness.
    rng_key, rng_key_, rng_key_pred = random.split(rng_key, 3)
    data, data_no_y = get_data()
    mcmc.run(rng_key_, **data)
    posterior_samples = mcmc.get_samples()
    predictions_jit = predict_fn_jit(rng_key_pred, posterior_samples, **data_no_y)["obs"]
    # Optional: presumably skips warmup on later runs by reusing the state
    # from the previous run. The original post did this through a private
    # attribute:
    #     mcmc._warmup_state = mcmc._last_state
    # NOTE(review): NumPyro also exposes
    #     mcmc.post_warmup_state = mcmc.last_state
    # as public API. Confirm against the installed version before enabling.