Parallelising NumPyro

Hi and thanks again for a great package, I really like using NumPyro!

I now have a setup working (see also this) where I am fitting a BetaBinomial model to some data and getting great results. I now have to fit multiple datasets (around 10,000 small ones) similar to the one I have tested on until now. What is the recommended approach for parallelising this? I have a server available with 40 cores and 200GB RAM.

Since the compilation time is the major time hurdle (~15s for compilation vs. ~3s for sampling), I would like to take advantage of jit compilation using jit_model_args=True in the MCMC.
As far as I've understood, the compiled MCMC object cannot be pickled and thus cannot be reused in the individual processes. Or is there any update on this? Is there any way of easily parallelising without paying the large compilation time for each run?

Thanks a lot!

Cheers,
Christian


I think you can make a wrapper like run_inference and leverage joblib to distribute the tasks to each CPU.
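For example, something along these lines (a rough sketch; model, all_datasets, and the chunk count are placeholders):

from jax.random import PRNGKey
from joblib import Parallel, delayed
from numpyro.infer import MCMC, NUTS

def run_inference(datasets):
    # each worker compiles once, then reuses the compiled model for its whole chunk
    mcmc = MCMC(NUTS(model), num_warmup=500, num_samples=1000, jit_model_args=True)
    samples = []
    for data in datasets:
        mcmc.run(PRNGKey(0), **data)
        samples.append(mcmc.get_samples())
    return samples

# one chunk of datasets per core
chunks = [all_datasets[i::4] for i in range(4)]
results = Parallel(n_jobs=4)(delayed(run_inference)(chunk) for chunk in chunks)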

On each CPU, if your datasets have the same shape, you can use jit_model_args=True. In case your datasets have different shapes and there is no local latent variable (e.g. the model in the previous thread works because it only contains global latent variables), you can pad the data and use mask to mask out the padded part in the model. This way, jit_model_args will work.
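A minimal sketch of the pad-and-mask idea (the model body and the pad helper are stand-ins, not the model from the previous thread; N_raw and y_raw stand for one un-padded dataset):

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

def pad(x, length, fill=1):
    # pad a 1-d array up to a common length with dummy values
    return jnp.concatenate([x, jnp.full(length - len(x), fill, x.dtype)])

def model(N, y=None, valid=True):
    p = numpyro.sample("p", dist.Beta(2, 3))  # global latent only
    # drop the log-likelihood contribution of the padded entries
    with numpyro.handlers.mask(mask=valid):
        numpyro.sample("obs", dist.Binomial(N, probs=p), obs=y)

L = 9  # length of the longest dataset in the group
data = {"N": pad(N_raw, L), "y": pad(y_raw, L), "valid": jnp.arange(L) < len(N_raw)}

Since all padded datasets then share one shape, jit_model_args=True only compiles once.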

You can optimize further by grouping datasets with similar shapes into the same task: e.g. if you have datasets A, B, C, D whose shapes are 7, 5, 3, 9, then you can run B and C on one CPU and A and D on another.
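Roughly (a sketch, assuming datasets is a list of dicts keyed like the data above):

import math

# sort by size so similar shapes land in the same chunk, minimising both
# padding waste and the number of distinct shapes each worker compiles for
datasets_sorted = sorted(datasets, key=lambda d: len(d["N"]))
n_chunks = 2
size = math.ceil(len(datasets_sorted) / n_chunks)
chunks = [datasets_sorted[i * size:(i + 1) * size] for i in range(n_chunks)]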


Hi again and sorry for the long wait for my reply!

I have tried looking into joblib and run_inference and I am already trying to take advantage of jit_model_args=True. All of my many datasets share exactly the same shape (and dtype) – it is only the numbers that are different – and thus no padding or masking is needed. I just want to fit as many of these datasets as possible.

It should thus be quite an easy problem. However, I seem to be running into memory leaks.

I have my entire data set consisting of, say, 20 smaller data sets, which are to be fitted individually; all of them are treated as independent. Since the compilation time is much larger than the sampling time, I want to create as few child workers as possible (while still maxing out all the cores), i.e. instead of splitting the 20 data sets into 10 chunks of 2, I would rather split them into 4 chunks of 5 (assuming four cores are available).

My code for now is the following:

import arviz as az
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from jax.random import PRNGKey as Key  # the code below uses Key as an alias
from joblib import Parallel, delayed
from multiprocessing import current_process
from numpyro.infer import MCMC, NUTS, Predictive

verbose = False  # set to True to print a summary after each fit


def get_data():
    # the same small dataset every time, simulating many similar ones
    N = jnp.array([7642688, 7609177, 8992872, 8679915, 8877887, 8669401])
    y = jnp.array([2036195, 745632, 279947, 200865, 106383, 150621])
    z = jnp.arange(1, len(N) + 1)
    data = {"z": z, "y": y, "N": N}
    return data

def model(z, N, y=None):
    q = numpyro.sample("q", dist.Beta(2, 3))  # mean = 0.4, shape = 5
    A = numpyro.sample("A", dist.Beta(2, 3))  # mean = 0.4, shape = 5
    c = numpyro.sample("c", dist.Beta(1, 9))  # mean = 0.1, shape = 10
    fz = numpyro.deterministic("fz", A * (1 - q) ** (z - 1) + c)

    phi = numpyro.sample("phi", dist.Exponential(1 / 1000))
    theta = numpyro.deterministic("theta", phi + 2)

    alpha = numpyro.deterministic("alpha", fz * theta)
    beta = numpyro.deterministic("beta", (1 - fz) * theta)

    numpyro.sample("obs", dist.BetaBinomial(alpha, beta, N), obs=y)


def get_results_from_mcmc(mcmc, data):
    # simulate getting some results

    posterior_samples = mcmc.get_samples()

    posterior_predictive = Predictive(model, posterior_samples)(
        Key(42), data["z"], data["N"]
    )["obs"]
    predictions_fraction = posterior_predictive / data["N"]

    y_mean = jnp.mean(predictions_fraction, axis=0)
    y_hpdi = numpyro.diagnostics.hpdi(predictions_fraction, 0.68)
    ds = az.from_numpyro(mcmc, posterior_predictive={"y_pred": posterior_predictive})

    return y_mean, y_hpdi, ds


def worker(N_runs=10):

    print(f"Worker Process: {current_process()=}", flush=True)

    mcmc = None

    for _ in range(N_runs):

        data = get_data()

        if mcmc is None:
            mcmc = MCMC(
                NUTS(model),
                progress_bar=verbose,
                jit_model_args=True,
                num_warmup=500,
                num_samples=1000,
                num_chains=1,
                chain_method="sequential",
            )

        mcmc.run(Key(42), **data)

        if verbose:
            mcmc.print_summary()

        # simulate getting some results
        res = get_results_from_mcmc(mcmc, data)

    return 0


if __name__ == "__main__":

    # results = worker(20)
    results = Parallel(n_jobs=2)(delayed(worker)(N_runs=10) for _ in range(2))
    print(results)

This code works and gives correct results. However, when testing it out, I measure a surprisingly large memory consumption with memory_profiler!
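For reference, I measure it roughly like this (the script name is just a placeholder):

$ mprof run parallel_numpyro.py
$ mprof plot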

Even after fitting just 20 small data sets in two chunks of ten (on 2 cores), the memory consumption is ~1.3GB, even without saving anything. With a data set that is not even that much larger, one runs into memory errors.

In comparison, when running just results = worker(20), the memory consumption is still surprisingly large, imo.

Do you have any good tips or advice on how to deal with this?
Optimally, one would fit only a single data set at a time (while reusing the same compiled mcmc object and taking advantage of jit_model_args=True), but this is still impossible across processes, right?

Sorry for the long post. Cheers and happy new year!

Hi @wc4, could you test with the master branch? We recently merged this PR, which addresses a memory leak issue.

Hi again and thanks for the fast answer!

I tried installing the master branch so that now I get:

$ conda list | grep numpyro
numpyro    0.4.1    dev_0    <develop>
$ pip list | grep numpyro
numpyro         0.4.1         /Users/christianmichelsen/software/numpyro
$ cd /Users/christianmichelsen/software/numpyro
$ git pull
Already up to date.

This indicates that I have installed the newest version of the master branch, I hope.

However, I still run into exactly the same problem, and the same graphs, as before, for both Parallel(n_jobs=2)(delayed(worker)(N_runs=10) for _ in range(2)) and worker(20).

By the way, I am running this on my local CPU, not a GPU.

Hmm, looking further into it, the problem seems to be with the Predictive function.

If I just run:

    for i in range(10):
        data = get_data()
        mcmc.run(PRNGKey(i), **data)
        posterior_samples = mcmc.get_samples()

I get a memory consumption that nicely plateaus after a few seconds. However, if I include Predictive:

    for i in range(10):
        data = get_data()
        mcmc.run(PRNGKey(i), **data)
        posterior_samples = mcmc.get_samples()

        predictive = Predictive(model, posterior_samples)
        posterior_predictive = predictive(PRNGKey(i), data["z"], data["N"])["obs"]

I get a memory consumption that keeps growing.

Is there a) an easy way to fix this or b) a hack to circumvent Predictive?

Hm, I tried to compare Predictive with the effect-handler function defined in the Bayesian Regression Tutorial. When using the effect-handler function, the memory consumption is constant over time, unlike with Predictive. The rest of the code can be seen at the bottom of this post.

This works (i.e. constant memory consumption over time):

N_runs = 100
for i in tqdm(range(N_runs)):
    rng_key, rng_key_ = random.split(rng_key)
    mcmc.run(rng_key_, marriage=marriage, divorce=divorce)
    samples_1 = mcmc.get_samples()
    
    predict_fn = vmap(lambda rng_key, samples: predict(rng_key, samples, model_marriage, marriage=marriage))
    predictions_1 = predict_fn(random.split(rng_key_, num_samples), samples_1)

whereas this doesn't (i.e. the memory consumption grows over time):

N_runs = 100
for i in tqdm(range(N_runs)):
    rng_key, rng_key_ = random.split(rng_key)
    mcmc.run(rng_key_, marriage=marriage, divorce=divorce)
    samples_1 = mcmc.get_samples()

    predictions = Predictive(model_marriage, samples_1)(rng_key_, marriage=marriage)["obs"]

I assume this is not as expected? Also, it seems that the effect handlers are a lot faster than Predictive: timeit gave 939 ms ± 44.1 ms for Predictive compared to 13.8 ms ± 1.6 ms for the effect handlers.


So for now I could just use the effect-handler method; however, I am having a bit of trouble getting it to work with my previously mentioned datasets. If I try:

predict_fn = vmap(lambda rng_key, samples: predict(rng_key, samples, model, data["z"], data["N"]))
predictions_1 = predict_fn(random.split(rng_key_, num_samples), samples_1)

where model and samples_1 are from my original dataset, it just runs and runs without getting anywhere, while using quite a lot of CPU. I thought I would be able to just use this method on my own dataset as well. Do you happen to see what I am missing here? If only this small "hack" could work, I would be more than happy to help with anything related to fixing Predictive if I can :slight_smile:


A last note: in the tutorial, in cell [12], where it says:

df['Mean Predictions'] = jnp.mean(predictions, axis=0)

shouldn’t it be:

df['Mean Predictions'] = jnp.mean(predictions_1, axis=0)

to reflect the new predictions from the effect handlers?


Rest of the code:

# in addition to the imports from the earlier snippets:
import pandas as pd
from jax import random, vmap
from jax.random import PRNGKey
from numpyro import handlers

def get_data_marriage():
    DATASET_URL = (
        "https://raw.githubusercontent.com/rmcelreath/rethinking/master/data/WaffleDivorce.csv"
    )
    dset = pd.read_csv(DATASET_URL, sep=";")
    standardize = lambda x: (x - x.mean()) / x.std()
    dset["AgeScaled"] = dset.MedianAgeMarriage.pipe(standardize)
    dset["MarriageScaled"] = dset.Marriage.pipe(standardize)
    dset["DivorceScaled"] = dset.Divorce.pipe(standardize)
    return dset.MarriageScaled.values, dset.DivorceScaled.values

def model_marriage(marriage, divorce=None):
    a = numpyro.sample("a", dist.Normal(0.0, 0.2))
    bM = numpyro.sample("bM", dist.Normal(0.0, 0.5))
    sigma = numpyro.sample("sigma", dist.Exponential(1.0))
    mu = a + bM * marriage
    numpyro.sample("obs", dist.Normal(mu, sigma), obs=divorce)

def predict(rng_key, post_samples, model, *args, **kwargs):
    # substitute one set of posterior draws into the seeded model, trace it,
    # and read off the value at the "obs" site
    model = handlers.condition(handlers.seed(model, rng_key), post_samples)
    model_trace = handlers.trace(model).get_trace(*args, **kwargs)
    return model_trace["obs"]["value"]


rng_key = PRNGKey(0)
num_warmup, num_samples = 1000, 2000
marriage, divorce = get_data_marriage()

mcmc = MCMC(
    NUTS(model_marriage),
    progress_bar=False,
    jit_model_args=True,
    num_warmup=num_warmup,
    num_samples=num_samples,
    num_chains=1,
    chain_method="sequential",
)

Thanks for pointing it out! Would you like to submit a fix? :slight_smile:

I assume this is not as expected?

Good catch! It is expected in the eyes of the JAX devs (probably :D), because different class/method/function calls will be triggered for different runs: each loop iteration builds a fresh Predictive instance, which gets traced and cached anew. A solution here is to make a wrapper like this

@jax.jit
def predict_fn(rng_key, samples, data):
    return Predictive(model, samples)(rng_key, data)

and use this predict_fn in your runs. Could you check if this works?
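For the BetaBinomial model from earlier in the thread, which takes z and N, that would look something like this (a sketch):

@jax.jit
def predict_fn(rng_key, samples, z, N):
    return Predictive(model, samples)(rng_key, z, N)

for i in range(10):
    data = get_data()
    mcmc.run(PRNGKey(i), **data)
    posterior_samples = mcmc.get_samples()
    # the wrapper is traced once and reused, instead of building
    # (and caching) a brand-new Predictive on every iteration
    posterior_predictive = predict_fn(PRNGKey(i), posterior_samples, data["z"], data["N"])["obs"]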

it seems that the effect handlers are a lot faster than Predictive

I guess you can set parallel=True in Predictive to make it faster? (I'm not sure; if it is not faster, then please create an issue on GitHub so we can address it.)


Yeah, gonna do that tomorrow (European time here, so it's pretty late :stuck_out_tongue:).


Using the previous method of just Predictive, the memory consumption on my original dataset kept growing.

When I use the jitted function you suggested, it grows much more slowly.

So a lot better! However, there still seems to be a (smaller) memory leak. Do you have any other ideas? This was run for 100 iterations of the loop, so the data sets wouldn't need to be much larger for this to eventually become a problem. Also, is it correct that the function predict_fn has to be defined inside the loop, since it cannot take the model as an input parameter?

Do you have any other ideas?

I'm not sure. Could you check the size of the posterior samples or the arviz datasets?
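For example, something like this would show how much memory the posterior-samples pytree holds (a quick sketch):

import jax

# bytes per sample site, and the total across the whole pytree
sizes = jax.tree_util.tree_map(lambda x: x.nbytes, posterior_samples)
total_mb = sum(jax.tree_util.tree_leaves(sizes)) / 1e6
print(sizes, f"{total_mb:.1f} MB")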

predict_fn would have to be defined in the loop

It is unnecessary, because your model is defined at global scope. To take advantage of jit, I think you need to define predict_fn outside of the loop.


Moving the predict_fn outside the loop solved the problem!

Thank you so much for all your help (and your very quick answers :smiley: )

Just leaving a working example for others who want to iterate over datasets without memory leaks:

from jax import jit, random
from numpyro.infer import MCMC, NUTS, Predictive
from tqdm import tqdm

# model and get_data as defined earlier in the thread; here get_data returns
# the data both with and without the observed `y` (for fitting vs. prediction)

rng_key = random.PRNGKey(0)
num_warmup, num_samples = 500, 1000  # as in the worker above

mcmc = MCMC(
    NUTS(model),
    progress_bar=False,
    jit_model_args=True,
    num_warmup=num_warmup,
    num_samples=num_samples,
    num_chains=1,
    chain_method="sequential",
)

@jit
def predict_fn_jit(rng_key, samples, *args, **kwargs):
    return Predictive(model, samples)(rng_key, *args, **kwargs)

N_runs = 100
for i in tqdm(range(N_runs)):
    rng_key, rng_key_ = random.split(rng_key)
    data, data_no_y = get_data()
    mcmc.run(rng_key_, **data)
    posterior_samples = mcmc.get_samples()
    predictions_jit = predict_fn_jit(rng_key_, posterior_samples, **data_no_y)["obs"]
    # mcmc._warmup_state = mcmc._last_state  # optional hack: reuse the last state as warmup for the next run