Hm, I tried to compare `Predictive` with the effect-handler function defined in the Bayesian Regression Tutorial. With the effect-handler function, memory consumption stays constant over time; with `Predictive`, it does not. The rest of the code can be found at the bottom of this post.

This works (i.e. memory consumption stays constant over time):

```
N_runs = 100
for i in tqdm(range(N_runs)):
    rng_key, rng_key_ = random.split(rng_key)
    mcmc.run(rng_key_, marriage=marriage, divorce=divorce)
    samples_1 = mcmc.get_samples()
    predict_fn = vmap(lambda rng_key, samples: predict(rng_key, samples, model_marriage, marriage=marriage))
    predictions_1 = predict_fn(random.split(rng_key_, num_samples), samples_1)
```

whereas this doesn't (i.e. memory consumption is not constant over time):

```
N_runs = 100
for i in tqdm(range(N_runs)):
    rng_key, rng_key_ = random.split(rng_key)
    mcmc.run(rng_key_, marriage=marriage, divorce=divorce)
    samples_1 = mcmc.get_samples()
    predictions = Predictive(model_marriage, samples_1)(rng_key_, marriage=marriage)["obs"]
```
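To put numbers on "constant" versus "not constant", one option is to record the process's peak resident memory after every iteration. The sketch below is an assumption about how to measure this, not how the numbers above were obtained; `track_peak_rss` and the `one_iteration` function in the usage comment are hypothetical names, and the standard-library `resource` module is only available on Unix-like systems.

```python
import resource


def track_peak_rss(step, n_runs):
    """Call step(i) n_runs times and record the peak RSS after each call.

    ru_maxrss is a high-water mark that never decreases; it is reported in
    kilobytes on Linux and in bytes on macOS. A curve that flattens out after
    the first few runs suggests constant memory use, while one that keeps
    climbing suggests memory is growing from run to run.
    """
    peaks = []
    for i in range(n_runs):
        step(i)
        peaks.append(resource.getrusage(resource.RUSAGE_SELF).ru_maxrss)
    return peaks


# Hypothetical usage: wrap one loop body from above in one_iteration().
# peaks = track_peak_rss(lambda i: one_iteration(), N_runs)
```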

I assume this is not the expected behavior? Also, the effect handlers seem to be a lot faster than `Predictive`: `timeit` gave `939 ms ± 44.1 ms` for `Predictive`, compared to `13.8 ms ± 1.6 ms` for the effect handlers.

So for now I could just use the effect-handler method; however, I am having some trouble getting it to work with the datasets I mentioned previously. If I try:

```
predict_fn = vmap(lambda rng_key, samples: predict(rng_key, samples, model, data["z"], data["N"]))
predictions_1 = predict_fn(random.split(rng_key_, num_samples), samples_1)
```

where `model` and `samples_1` come from my original dataset, it just keeps running without getting anywhere, using quite a lot of CPU. I expected to be able to use this method on my own dataset as well. Do you see what I am missing here? If I can get this small “hack” to work, I would be more than happy to help with fixing `Predictive` in any way I can.
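One way to narrow down the hang might be to drop the `vmap` and call `predict` on a single posterior draw first: if one draw already takes very long or never returns, the problem is likely in the model call itself rather than in the vectorization. The helpers below (`first_draw`, `try_single_draw`) are hypothetical, and they assume the samples are a flat dict with the draws along the leading axis, as `mcmc.get_samples()` returns by default.

```python
import time


def first_draw(samples):
    """Pick the first posterior draw from a dict of stacked samples.

    Assumes every value has the draws along its leading axis, as with
    mcmc.get_samples() when group_by_chain=False (the default).
    """
    return {name: value[0] for name, value in samples.items()}


def try_single_draw(predict, rng_key, samples, model, *args, **kwargs):
    """Run predict on one draw without vmap and return (output, seconds)."""
    start = time.perf_counter()
    out = predict(rng_key, first_draw(samples), model, *args, **kwargs)
    if hasattr(out, "block_until_ready"):
        out.block_until_ready()  # wait for JAX's asynchronous dispatch to finish
    return out, time.perf_counter() - start


# Usage with the names from this post:
# out, seconds = try_single_draw(predict, rng_key_, samples_1, model, data["z"], data["N"])
```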

A last note: in the tutorial, cell [12] says:

`df['Mean Predictions'] = jnp.mean(predictions, axis=0)`

Shouldn't it be:

`df['Mean Predictions'] = jnp.mean(predictions_1, axis=0)`

to reflect the new predictions from the effect handlers?
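For reference, the vmapped `predict_fn` returns one row per posterior draw and one column per observation, so averaging over `axis=0` leaves one mean prediction per row of `df`. The sketch uses a synthetic NumPy array as a stand-in for `predictions_1` (an assumption, not real output); `jnp.mean` follows the same axis semantics as `np.mean`.

```python
import numpy as np

num_samples, num_obs = 2000, 50
# Synthetic stand-in for predictions_1: shape (num_samples, num_obs),
# i.e. one row per posterior draw and one column per observation.
predictions_1 = np.random.default_rng(0).normal(size=(num_samples, num_obs))

# Averaging over axis 0 collapses the posterior draws and keeps
# one mean prediction per observation (one value per row of df).
mean_predictions = np.mean(predictions_1, axis=0)
```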

Rest of the code:

```
import numpyro
import numpyro.distributions as dist
import pandas as pd
from jax import random, vmap
from jax.random import PRNGKey
from numpyro import handlers
from numpyro.infer import MCMC, NUTS, Predictive
from tqdm import tqdm


def get_data_marriage():
    DATASET_URL = (
        "https://raw.githubusercontent.com/rmcelreath/rethinking/master/data/WaffleDivorce.csv"
    )
    dset = pd.read_csv(DATASET_URL, sep=";")
    standardize = lambda x: (x - x.mean()) / x.std()
    dset["AgeScaled"] = dset.MedianAgeMarriage.pipe(standardize)
    dset["MarriageScaled"] = dset.Marriage.pipe(standardize)
    dset["DivorceScaled"] = dset.Divorce.pipe(standardize)
    return dset.MarriageScaled.values, dset.DivorceScaled.values


def model_marriage(marriage, divorce=None):
    a = numpyro.sample("a", dist.Normal(0.0, 0.2))
    bM = numpyro.sample("bM", dist.Normal(0.0, 0.5))
    sigma = numpyro.sample("sigma", dist.Exponential(1.0))
    mu = a + bM * marriage
    numpyro.sample("obs", dist.Normal(mu, sigma), obs=divorce)


def predict(rng_key, post_samples, model, *args, **kwargs):
    # Seed the model, fix its latent sites to one posterior draw, and record
    # the trace; the "obs" site is then sampled from the predictive distribution.
    model = handlers.condition(handlers.seed(model, rng_key), post_samples)
    model_trace = handlers.trace(model).get_trace(*args, **kwargs)
    return model_trace["obs"]["value"]


rng_key = PRNGKey(0)
num_warmup, num_samples = 1000, 2000
marriage, divorce = get_data_marriage()
mcmc = MCMC(
    NUTS(model_marriage),
    progress_bar=False,
    jit_model_args=True,
    num_warmup=num_warmup,
    num_samples=num_samples,
    num_chains=1,
    chain_method="sequential",
)
```