NumPyro out of memory when running SVI in parallel

I’m having trouble running NumPyro’s SVI in parallel. I have a dataset of about a billion data points, which I have split into batches of ~0.5 million data points each. Running SVI on a single batch uses < 3 GB of memory. But if I run two batches in parallel, the process runs out of memory on a machine with 128 GB of RAM.

My model and guide are defined in this post, and I’m using the following functions to run things in parallel:

def run_inference(data, times, mask, L, Sl, batch_idx):
    def _run_svi_for_crop(c_idx):
        optimizer = numpyro.optim.Adam(step_size=0.005)
        svi = infer.SVI(my_model, ci_guide, optimizer, loss=infer.Trace_ELBO(num_particles=1))
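        # rng_key: a jax.random.PRNGKey (assumed here; the key used is not shown in the post)
        return svi.run(rng_key, n_iter, data, times, mask, L, pi, gamma, sigma_gamma, beta_s, beta_h, omega_s, omega_h, sigma, Sl, c_obs=jnp.full(L, c_idx, dtype='int32'), progress_bar=False)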
    svi_results = list(tqdm_notebook(map(_run_svi_for_crop, range(2)), total=2, desc="Running SVI for batch_idx {}".format(batch_idx)))
    svi_params = [result.params for result in svi_results]
    return svi_params

def run_for_batch(batch_file_name):
    print("Loading data from {}".format(batch_file_name))
    data, times, batch_idx = load_data(batch_file_name)
    # for numpyro.handlers.mask
    mask = times > 0
    # data, times and mask are arrays of shape Sl x L
    # L ~ 0.5 million, Sl ~ 100

    L = data.shape[1]
    Sl = data.shape[0]
    print("Running inference for batch_idx {}".format(batch_idx))
    results = run_inference(data, times, mask, L, Sl, int(batch_idx))
    print("Saving results for batch_idx {}".format(batch_idx))

with ThreadPool(max_workers=32) as batch_pool:
    result = list(tqdm_notebook(batch_pool.map(run_for_batch, file_names), total=len(file_names)))

I searched around and found this issue (not exactly related) about memory and JAX’s compilation cache. I tried applying @jit first to _run_svi_for_crop and then to run_inference. In both cases, I get a different error each time I run it. The errors I have seen are:

  • complains about model sites being duplicated
  • complains about incompatible shapes (this is because L can be different for the last batch which is < 0.5 million data points, hence the size of data, times and mask would be different for that one batch)
  • complains of a leaked value in numpyro/ (process_message)
  • assertion error about PYRO_STACK[-1] == self (I don’t remember the exact error, it only happened once)
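
For the incompatible-shapes complaint specifically, one option is to pad the short final batch up to the common length so that every batch has the same shape. This is only a sketch under assumptions: `pad_batch` and `target_len` are hypothetical names, plain NumPy stands in for the real loader, and it relies on the `mask = times > 0` convention from the code above so that padded columns are masked out. Whether the model treats masked columns as intended still has to be checked against the actual model and guide.

```python
import numpy as np

def pad_batch(data, times, target_len):
    """Zero-pad the L axis (axis 1) of Sl x L arrays up to target_len."""
    pad = target_len - data.shape[1]
    if pad < 0:
        raise ValueError("batch is longer than target_len")
    widths = ((0, 0), (0, pad))  # no padding on Sl, pad only the end of L
    data_p = np.pad(data, widths)
    times_p = np.pad(times, widths)
    # Padded columns have times == 0, so the same rule as above masks them out.
    mask_p = times_p > 0
    return data_p, times_p, mask_p

# Toy shapes standing in for Sl ~ 100 and L ~ 0.5 million.
data = np.ones((3, 5))
times = np.arange(1, 16, dtype=float).reshape(3, 5)
data_p, times_p, mask_p = pad_batch(data, times, target_len=8)
```

With every batch padded to the same `L`, a jitted function would see one input shape instead of two. Note that anything built from `L`, such as `c_obs`, would then use the padded length.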

I don’t know anything about jax.jit, including whether this is the right thing to use here, so any advice is greatly appreciated.

Maybe the VAE example is helpful? FYI, jax.jit(f) compiles the computation, so if you call it again with similar inputs (same structure, dtype, and shape), the second run will be much faster. If jax.jit gives you errors, that means it can’t compile your program; in that case, you can remove jax.jit.
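
To make the caching behaviour concrete, here is a small sketch that is unrelated to the model in this thread. The function `total` and the counter `trace_count` are made up for illustration. The counter is bumped by a Python side effect, which only runs while JAX traces the function, so it counts traces rather than calls.

```python
import jax
import jax.numpy as jnp

trace_count = 0

@jax.jit
def total(x):
    global trace_count
    trace_count += 1  # Python side effect: executed only during tracing
    return jnp.sum(x * 2.0)

total(jnp.ones(4))
total(jnp.zeros(4))  # same shape and dtype: the compiled version is reused
after_same_shape = trace_count

total(jnp.ones(5))   # different shape: traced and compiled again
after_new_shape = trace_count
```

The same mechanism would apply if the last batch has a different `L` from the others: a jitted function gets traced and compiled again for that shape.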

Sorry, I forgot to link the issue! It’s this one: #622. In particular, this comment made me look into jax.jit as a way to circumvent the compilation cache.

I’m not using a GPU, unlike in the above issue, but I still suspect it might be what’s causing the memory problem. And indeed, in one of my runs with jax.jit, everything worked with no memory errors. But that only happened once. I’m not sure why the errors change on every run; perhaps there is a race condition that depends on which batch gets executed first.

The leaked value error is the most common, so I could debug it further if that’s helpful. Here’s the stack trace:

UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type uint32[2] wrapped in a DynamicJaxprTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was run_svi at /tmp/ipykernel_2834/ traced for jit.
The leaked intermediate value was created on line /home/user/cibo/numpyro/numpyro_env/lib/python3.10/site-packages/numpyro/ (process_message). 
When the value was created, the final 5 stack frames (most recent last) excluding JAX-internal frames were:
/home/user/cibo/numpyro/numpyro_env/lib/python3.10/site-packages/numpyro/ (__call__)
/tmp/ipykernel_2834/ (ci_guide)
/home/user/cibo/numpyro/numpyro_env/lib/python3.10/site-packages/numpyro/ (sample)
/home/user/cibo/numpyro/numpyro_env/lib/python3.10/site-packages/numpyro/ (apply_stack)
/home/user/cibo/numpyro/numpyro_env/lib/python3.10/site-packages/numpyro/ (process_message)

To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.
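
For reference, the same class of error can be reproduced outside NumPyro. This is a minimal sketch, not a diagnosis of the trace above: `leaky`, `clean`, and `saved` are hypothetical names. In `leaky`, a jitted function stores an intermediate value in global state, and using that value afterwards raises `UnexpectedTracerError`. In `clean`, the same value is returned explicitly instead.

```python
import jax
import jax.numpy as jnp

saved = []

@jax.jit
def leaky(x):
    y = x + 1
    saved.append(y)  # side effect: a tracer escapes into global state
    return y

@jax.jit
def clean(x):
    y = x + 1
    return y, y  # intermediate value is returned instead of stored

leaky(jnp.ones(2))
try:
    saved[0] + 1  # using the escaped tracer outside the jitted function
    leaked = False
except jax.errors.UnexpectedTracerError:
    leaked = True

out, intermediate = clean(jnp.ones(2))
```

Handler stacks like the one named in the PYRO_STACK assertion are a form of global state too, which may be relevant when several threads trace at the same time, though that is a guess rather than something confirmed in this thread.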

The VAE example you linked seems to run things sequentially, which doesn’t help in my case. If there is an example that runs things in parallel, that would be very helpful, so I can see what I’m doing differently that’s causing the memory issues.