I’m having issues running Numpyro’s SVI in parallel. I have a dataset of ~billion data points, which I have broken up into batches of ~0.5 million data points. If I run SVI for a single batch, the whole thing uses < 3GB memory. But if I run two batches in parallel, the process runs out of memory, on a machine with 128 GB RAM.
My model and guide are defined in an earlier post; I am using the following functions to run things in parallel:
def run_inference(data, times, mask, L, Sl, batch_idx, n_crops=2):
    """Fit SVI once per crop index and return the fitted parameters.

    Parameters
    ----------
    data, times, mask : arrays of shape (Sl, L) — observations, observation
        times, and the validity mask forwarded to the model.
    L, Sl : int — the two array dimensions (L ~ 0.5 million, Sl ~ 100 per the
        comments in the loader below).
    batch_idx : batch identifier, used only in the progress-bar description.
    n_crops : int, default 2 — number of crop indices to fit (previously a
        hard-coded ``2``).

    Returns
    -------
    list — one ``.params`` dict per crop index.

    NOTE(review): the ``map`` below runs crops *sequentially* inside this
    call; parallelism comes only from the outer thread pool. Each ``svi.run``
    re-traces and re-compiles the model, and concurrent calls from several
    threads multiply compilation caches and device buffers in this single
    process — a plausible source of the reported OOM; verify with profiling.
    """
    def _run_svi_for_crop(c_idx):
        # Fresh optimizer/SVI state per crop so the fits are independent.
        optimizer = numpyro.optim.Adam(step_size=0.005)
        svi = infer.SVI(my_model, ci_guide, optimizer,
                        loss=infer.Trace_ELBO(num_particles=1))
        # c_obs pins all L sites to this crop index.
        return svi.run(jrng_key, n_iter, data, times, mask, L, pi, gamma,
                       sigma_gamma, beta_s, beta_h, omega_s, omega_h, sigma,
                       Sl, c_obs=jnp.full(L, c_idx, dtype='int32'),
                       progress_bar=False)

    # `map` consumes a range lazily; no need to materialize `list(range(...))`.
    svi_results = list(tqdm_notebook(map(_run_svi_for_crop, range(n_crops)),
                                     total=n_crops,
                                     desc="Running SVI for batch_idx {}".format(batch_idx)))
    return [result.params for result in svi_results]
def run_for_batch(batch_file_name):
    """Load one batch file, run inference on it, and save the results.

    Parameters
    ----------
    batch_file_name : path of the batch to load via ``load_data``.

    Returns
    -------
    The inference results. The original version returned ``None`` implicitly,
    so the caller's ``pool.map`` collected a list of ``None`` — returning the
    results makes that list meaningful (backward-compatible: no caller relied
    on the ``None``).
    """
    print(f"Loading data from {batch_file_name}")
    data, times, batch_idx = load_data(batch_file_name)
    # Mask for numpyro.handlers.mask: non-positive times are treated as
    # missing/padded observations.
    mask = times > 0
    # data, times and mask are arrays of shape (Sl, L);
    # L ~ 0.5 million, Sl ~ 100 — roughly 400 MB apiece at float64.
    L = data.shape[1]
    Sl = data.shape[0]
    print(f"Running inference for batch_idx {batch_idx}")
    results = run_inference(data, times, mask, L, Sl, int(batch_idx))
    print(f"Saving results for batch_idx {batch_idx}")
    save_results(results)
    return results
# NOTE(review): max_workers=32 is the likely OOM culprit. Every worker thread
# shares this one process's memory, and each in-flight batch holds three
# (Sl, L) ~ (100, 0.5M) arrays (~400 MB apiece at float64) plus JAX device
# buffers and compiled-program state from svi.run. The report says even two
# concurrent batches exhaust 128 GB, so keep concurrency minimal and raise it
# only while watching RSS. For real isolation, prefer a *process* pool: each
# worker then owns its own JAX runtime and its memory is reclaimed on exit.
MAX_CONCURRENT_BATCHES = 2  # was 32

with ThreadPool(max_workers=MAX_CONCURRENT_BATCHES) as batch_pool:
    result = list(tqdm_notebook(batch_pool.map(run_for_batch, file_names),
                                total=len(file_names)))