How to reduce memory usage?

Hello,

I’m having a “Memory Exhaustion” issue.

Is there a way to reduce the algorithm’s memory usage? (Some sort of batch_size?)

Thanks in advance

Without further information about your model (and ideally some code), I’m afraid it’s impossible to answer this here.

I’m using some code I found online; here it is:

import numpy as np

import numpyro
import numpyro.distributions as dist


def simple_elasticity_model(log_price, sku_idx, log_quantity=None):
    n_obs = log_price.size
    n_sku = np.unique(sku_idx).size

    # per-SKU intercept, price elasticity, and noise scale
    with numpyro.plate("sku", n_sku):
        sku_intercept = numpyro.sample("sku_intercept", dist.Normal(loc=0, scale=1))
        beta_log_price = numpyro.sample("beta_log_price", dist.Normal(loc=0, scale=1))
        sigma_sku = numpyro.sample("sigma", dist.HalfNormal(scale=1))

    # expected log-quantity and noise scale for every observation
    mu = beta_log_price[sku_idx] * log_price + sku_intercept[sku_idx]
    sigma = sigma_sku[sku_idx]

    # likelihood over all n_obs observations at once
    with numpyro.plate("data", n_obs):
        numpyro.sample("obs", dist.Normal(loc=mu, scale=sigma), obs=log_quantity)

Then the training step:

import matplotlib.pyplot as plt
from jax import random

from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoNormal

simple_guide = AutoNormal(simple_elasticity_model)
simple_optimizer = numpyro.optim.Adam(step_size=0.01)

simple_svi = SVI(
    simple_elasticity_model,
    simple_guide,
    simple_optimizer,
    loss=Trace_ELBO(),
)

num_steps = 25_000

# rng_key is not defined in the snippet above; seeding it here so the code runs
rng_key = random.PRNGKey(0)
rng_key, rng_subkey = random.split(key=rng_key)
simple_svi_result = simple_svi.run(
    rng_subkey,
    num_steps,
    log_price,
    sku_idx,
    log_quantity,
)

# plot the ELBO loss trace
fig, ax = plt.subplots(figsize=(9, 6))
ax.plot(simple_svi_result.losses)
ax.set_title("ELBO loss", fontsize=18, fontweight="bold");

That is all the code I have for the model.

You can do mini-batching by using the subsample_size arg in your plate(s); see e.g. here.
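
For example, here is a minimal sketch of the pattern on a toy regression (the names x and y and the batch size of 256 are placeholders, not taken from your model). With subsample_size set, the plate draws a random subset of indices on every step and rescales the log-likelihood by n_obs / subsample_size, so the ELBO remains an unbiased estimate of the full-data objective:

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist


def toy_model(x, y=None):
    n_obs = x.shape[0]
    # make sure the data are JAX arrays so they can be indexed with the
    # traced subsample indices inside the SVI step
    x = jnp.asarray(x)
    y = None if y is None else jnp.asarray(y)

    beta = numpyro.sample("beta", dist.Normal(loc=0, scale=1))
    sigma = numpyro.sample("sigma", dist.HalfNormal(scale=1))

    # the plate yields a fresh random subset of 256 indices each step and
    # rescales the likelihood of that subset by n_obs / 256
    with numpyro.plate("data", n_obs, subsample_size=256) as idx:
        numpyro.sample(
            "obs",
            dist.Normal(loc=beta * x[idx], scale=sigma),
            obs=None if y is None else y[idx],
        )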

I set subsample_size to 1:

def simple_elasticity_model(log_price, sku_idx, log_quantity=None):
    n_obs = log_price.size
    n_sku = np.unique(sku_idx).size

    with numpyro.plate("sku", n_sku, subsample_size=1):
        sku_intercept = numpyro.sample("sku_intercept", dist.Normal(loc=0, scale=1))
        beta_log_price = numpyro.sample("beta_log_price", dist.Normal(loc=0, scale=1))
        sigma_sku = numpyro.sample("sigma", dist.HalfNormal(scale=1))

    mu = beta_log_price[sku_idx] * log_price + sku_intercept[sku_idx]

    sigma = sigma_sku[sku_idx]

    with numpyro.plate("data", n_obs):
        numpyro.sample("obs", dist.Normal(loc=mu, scale=sigma), obs=log_quantity)

But I’m getting the same error. Is there another way to control the amount of data the model receives?

It’s probably the second plate that is the issue. See SVI Part II: Conditional Independence, Subsampling, and Amortization — Pyro Tutorials 1.9.1 documentation.
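
Concretely, for your model something along these lines should keep the per-step memory proportional to the batch size rather than the full dataset. This is just a sketch (untested on your data); the batch_size of 1024 is a placeholder, and it assumes the imports from the snippets above (numpy as np, jax.numpy as jnp, numpyro, and numpyro.distributions as dist):

def subsampled_elasticity_model(log_price, sku_idx, log_quantity=None, batch_size=1024):
    n_obs = log_price.size
    n_sku = np.unique(sku_idx).size

    # convert to JAX arrays so they can be indexed with the traced subsample indices
    log_price = jnp.asarray(log_price)
    sku_idx = jnp.asarray(sku_idx)
    log_quantity = None if log_quantity is None else jnp.asarray(log_quantity)

    # per-SKU parameters: this plate is small, so keep it at full size
    with numpyro.plate("sku", n_sku):
        sku_intercept = numpyro.sample("sku_intercept", dist.Normal(loc=0, scale=1))
        beta_log_price = numpyro.sample("beta_log_price", dist.Normal(loc=0, scale=1))
        sigma_sku = numpyro.sample("sigma", dist.HalfNormal(scale=1))

    # subsample the data plate: only batch_size rows enter the likelihood per
    # step, and NumPyro rescales their log-density by n_obs / batch_size
    with numpyro.plate("data", n_obs, subsample_size=batch_size) as idx:
        batch_sku = sku_idx[idx]
        mu = beta_log_price[batch_sku] * log_price[idx] + sku_intercept[batch_sku]
        sigma = sigma_sku[batch_sku]
        numpyro.sample(
            "obs",
            dist.Normal(loc=mu, scale=sigma),
            obs=None if log_quantity is None else log_quantity[idx],
        )

The guide and the svi.run call from your post should work unchanged with this model; the memory saving comes from the likelihood only ever touching batch_size observations per step.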

Martin is probably right about the all-data likelihood causing the crash. If you still run into memory issues when using the indices from plate('data', n_obs, subsample_size=k) for k << n_obs, you should have a look at data loaders in JAX: Training a simple neural network, with tensorflow/datasets data loading — JAX documentation. You will still need to do the appropriate likelihood scaling, either using the plate or manually using handlers.scale. See the SVI Part II tutorial that Martin linked.
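
If you go the manual route (feeding a different mini-batch every step yourself), the scaling could look roughly like this. It is only a sketch under my own naming assumptions: log_price_full, sku_idx_full, and log_quantity_full stand in for your full-dataset arrays, and the batch size of 1024 is arbitrary:

import numpy as np
from jax import random

import numpyro
import numpyro.distributions as dist
from numpyro import handlers
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoNormal


def minibatch_model(log_price, sku_idx, log_quantity, n_sku, n_total):
    # the three data arguments are a mini-batch here, not the full dataset
    batch_size = log_price.size

    with numpyro.plate("sku", n_sku):
        sku_intercept = numpyro.sample("sku_intercept", dist.Normal(loc=0, scale=1))
        beta_log_price = numpyro.sample("beta_log_price", dist.Normal(loc=0, scale=1))
        sigma_sku = numpyro.sample("sigma", dist.HalfNormal(scale=1))

    mu = beta_log_price[sku_idx] * log_price + sku_intercept[sku_idx]
    sigma = sigma_sku[sku_idx]

    # rescale the mini-batch likelihood so it stands in for all n_total rows
    with handlers.scale(scale=n_total / batch_size):
        with numpyro.plate("data", batch_size):
            numpyro.sample("obs", dist.Normal(loc=mu, scale=sigma), obs=log_quantity)


# hypothetical full-dataset arrays
n_obs = log_price_full.size
n_sku = np.unique(sku_idx_full).size
batch_size = 1024
rng_np = np.random.default_rng(0)

guide = AutoNormal(minibatch_model)
svi = SVI(minibatch_model, guide, numpyro.optim.Adam(step_size=0.01), loss=Trace_ELBO())

# initialize on one random batch, then update on a fresh batch each step
idx = rng_np.choice(n_obs, size=batch_size, replace=False)
svi_state = svi.init(
    random.PRNGKey(0),
    log_price_full[idx], sku_idx_full[idx], log_quantity_full[idx], n_sku, n_obs,
)

for step in range(25_000):
    idx = rng_np.choice(n_obs, size=batch_size, replace=False)
    svi_state, loss = svi.update(
        svi_state,
        log_price_full[idx], sku_idx_full[idx], log_quantity_full[idx], n_sku, n_obs,
    )

params = svi.get_params(svi_state)

The scale factor n_total / batch_size is what keeps the ELBO comparable to the full-data objective; without it the mini-batch likelihood would be down-weighted relative to the priors.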