Scaling a NumPyro Model to 5B rows

I have a hierarchical Poisson GLM in NumPyro that I’d like to scale from 15 million rows of data (which takes ~15 minutes to fit locally) to 15 billion rows.

Are there any examples of how to scale models in production, e.g. by using a GPU, parallelizing, or subsampling? I’m having trouble finding any. I’d also be willing to use Pyro if it’s easier, but my model currently seems more performant in NumPyro.
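Side note on the GPU part: with a CUDA-enabled jaxlib install, pointing NumPyro at the GPU is mostly a one-line configuration change, and the model code itself stays the same. A minimal sketch, assuming such an install:

import jax
import numpyro

# assumes a CUDA-enabled jaxlib build; call this before any JAX computation runs
numpyro.set_platform("gpu")
print(jax.devices())  # confirm the GPU backend is actually being used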

Here’s an example of how I’m calling my model and guide functions:

def prep_data(data):

    # Get data at 1 row per sku
    sku_level = data.groupby("sku").first()[
        ['product_id', 'category'] 
        + data.filter(like="size_").columns.tolist()
    ]
    
    # Positional inputs for the model, in the order its signature expects
    results = [
        sku_level.product_id.cat.codes.values,
        data.sku.cat.codes.values,
        sku_level.category.cat.codes.values,
        data.seasonality.cat.codes.values,
        data.month.values-1,
        data.dow.values,
        data.hour.values,
        sku_level.filter(like='size_').iloc[:, 1:].values,  # drop the first size_ column
        data.time_since_published.values,
        data.rentals_obs.values,
        data.observed.values
    ]
    return results

from jax import random
import numpyro
from numpyro.infer import Predictive, Trace_ELBO

optimizer = numpyro.optim.Adam(step_size=0.0005)
svi = numpyro.infer.SVI(
    hierarchical_poisson_censored_regression,
    hierarchical_guide,
    optimizer, loss=Trace_ELBO())

inputs = prep_data(data)
svi_result = svi.run(random.PRNGKey(0), 2000, *inputs)
params = svi_result.params

# get posterior samples
predictive = Predictive(hierarchical_guide, params=params, num_samples=1000)
samples = predictive(random.PRNGKey(1), *inputs)
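As a side note, the svi_result returned by svi.run also carries the per-step loss trace, which gives a quick convergence check (a minimal sketch):

import matplotlib.pyplot as plt

# svi_result.losses holds the ELBO loss recorded at every SVI step
plt.plot(svi_result.losses)
plt.xlabel("SVI step")
plt.ylabel("ELBO loss")
plt.show()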

And here’s the model:

import numpy as np
import jax.numpy as jnp
import jax.scipy.special as jsp
import numpyro
import numpyro.distributions as dist


def poisson_logcdf(value, mu):
    # log P(Y <= value) for Y ~ Poisson(mu), via the regularized upper incomplete gamma;
    # clamp negative inputs before gammaincc (double-where trick to avoid NaNs)
    safe_mu = jnp.where(jnp.less(mu, 0), 0, mu)
    safe_value = jnp.where(jnp.less(value, 0), 0, value)

    res = jnp.where(
        jnp.less(value, 0),
        -jnp.inf,
        jnp.log(jsp.gammaincc(safe_value + 1, safe_mu)),
    )
    return res

def poisson_cdf(value, mu):
    return jnp.exp(poisson_logcdf(value, mu))
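As a quick sanity check on the censoring term (a sketch; it assumes SciPy is available), 1 - poisson_cdf(y, mu) should match the Poisson survival function:

import numpy as np
from scipy import stats

y_val, mu_val = 3.0, 2.5
np.testing.assert_allclose(
    1 - float(poisson_cdf(y_val, mu_val)),  # censored probability P(Y > y)
    stats.poisson.sf(y_val, mu_val),
    rtol=1e-4,
)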


def hierarchical_poisson_censored_regression(
    product_j,
    sku_k,
    category_code,
    season_code,
    month_code,
    dow_code,
    hour_code,
    size_matrix,
    days_since_publish,
    y,
    E
):
    # Shapes
    n_products = len(np.unique(product_j))
    n_skus = len(np.unique(sku_k))
    n_categories = len(np.unique(category_code))
    n_sizes = size_matrix.shape[1]

    # Masks
    censored_mask = (E == 0)
    observed_mask = (E == 1)

    # Product level Coefficients
    mu_alpha_product = numpyro.sample("mu_alpha_product", dist.Normal(-4, 1))
    sigma_alpha_product = numpyro.sample("sigma_alpha_product", dist.Exponential(2.5))
    with numpyro.plate("products", n_products):
        alpha_product = numpyro.sample("alpha_product", dist.Normal(mu_alpha_product, sigma_alpha_product))
    with numpyro.plate("categories", n_categories):
        b_category = numpyro.sample("b_category", dist.Normal(0, .5))
    with numpyro.plate("sizes", n_sizes):
        b_size = numpyro.sample("b_size", dist.Normal(0, .5))

    # SKU level Coefficients
    b_decay = numpyro.sample("b_decay", dist.Normal(-0.4, 0.1))
    with numpyro.plate("dayofweek", 7):
        b_dow = numpyro.sample("b_dow", dist.Normal(0, .5))
    with numpyro.plate("hours", 24):
        b_hour = numpyro.sample("b_hour", dist.Normal(0, .5))
    with numpyro.plate("months", 12):
        b_month = numpyro.sample("b_month", dist.Normal(0, .5))
    with numpyro.plate("seasonalities", 3):
        b_season = numpyro.sample("b_season", dist.Normal(0, .5))

    # Top-level hierarchy, product > skus
    mu_alpha = alpha_product[product_j] + b_category[category_code] + b_size @ size_matrix.T
    sigma_alpha = numpyro.sample("sigma_alpha", dist.Exponential(2.5))
    with numpyro.plate("skus", n_skus):
        alpha = numpyro.sample("alpha", dist.Normal(mu_alpha, sigma_alpha))
    
    # regression with sku random effects
    log_lambd = numpyro.deterministic("log_lambd",
        alpha[sku_k] + b_month[month_code] * b_season[season_code] + b_dow[dow_code] + b_hour[hour_code] 
                + b_decay * days_since_publish / 365
    )
    mu = jnp.exp(log_lambd)

    with numpyro.plate("data", days_since_publish.shape[0]):
        # Observed
        numpyro.sample("obs", dist.Poisson(mu).mask(observed_mask), obs=y)

        # Censored
        censored_prob = 1 - poisson_cdf(y, mu=mu)
        numpyro.sample("censored_label", dist.Bernoulli(censored_prob).mask(censored_mask), obs=E)

e.g. this example uses SVI + data subsampling (mini-batching)

If you are using SVI, it’s easy to add a subsampling arg. Now I’m wondering if NUTS supports subsampling.
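For reference, here is a stand-alone sketch (a toy model, not the one above) of what that subsampling arg looks like: the plate draws a random mini-batch of indices at every SVI step, numpyro.subsample slices the full arrays down to that batch, and the plate rescales the likelihood back up to the full data size:

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

def subsampled_model(x, y):
    alpha = numpyro.sample("alpha", dist.Normal(0.0, 1.0))
    beta = numpyro.sample("beta", dist.Normal(0.0, 1.0))
    with numpyro.plate("data", x.shape[0], subsample_size=5000):
        batch_x = numpyro.subsample(x, event_dim=0)  # slice inputs to the current mini-batch
        batch_y = numpyro.subsample(y, event_dim=0)
        mu = jnp.exp(alpha + beta * batch_x)
        numpyro.sample("obs", dist.Poisson(mu), obs=batch_y)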

Thank you both, and I appreciate the example.

Is using subsample_size in a plate the same as mini-batching? I’m not seeing much improvement from doing so:

    with numpyro.plate("data", days_since_publish.shape[0], subsample_size=500) as ind:

        # Observed
        numpyro.sample("obs", dist.Poisson(mu[ind]), obs=y[ind])

        # Censored
        censored_prob = 1 - poisson_cdf(y, mu=mu)
        numpyro.sample("censored_label", dist.Bernoulli(censored_prob[ind]).mask(censored_mask[ind]), obs=E[ind])

My model fit in 15 minutes on 13 million rows using this subsample_size, while it fit in 17 minutes on the same data without any subsample_size arg:

    with numpyro.plate("data", days_since_publish.shape[0] ):

        # Observed
        numpyro.sample("obs", dist.Poisson(mu]), obs=y])

        # # Censored
        censored_prob = 1 - poisson_cdf(y, mu=mu)
        numpyro.sample("censored_label", dist.Bernoulli(censored_prob]).mask(censored_mask]), obs=E])

Is there something else I’m doing wrong here? Or does it just take much larger data to see sizeable improvements?

I didn’t look at your model in detail, but it appears the censored_prob line operates on all of the data? Can this not be something like

        censored_prob = 1 - poisson_cdf(y[ind], mu=mu[ind])
        numpyro.sample("censored_label", dist.Bernoulli(censored_prob).mask(censored_mask[ind]), obs=E[ind])

Thank you! I didn’t quite understand the order in which operations were happening before you said this. My model now fits in 15 seconds and converges on the correct parameter estimates. Putting the calculations under the subsampled plate did the trick:

    with numpyro.plate("data", days_since_publish.shape[0], subsample_size=5000) as ind:

        # Masks
        censored_mask = (E[ind] == 0)
        observed_mask = (E[ind] == 1)

        # regression with sku random effects
        log_lambd = numpyro.deterministic("log_lambd",
            alpha[sku_k[ind]] 
            + b_month[month_code[ind]] * (1 + b_FW * fall_winter_ind[ind] + b_SL * seasonless_ind[ind]) 
            + b_dow[dow_code[ind]] + b_hour[hour_code[ind]] 
            + b_decay * days_since_publish[ind] / 365
        )
        mu = jnp.exp(log_lambd)

        # Observed
        numpyro.sample("obs", dist.Poisson(mu).mask(observed_mask), obs=y[ind])

        # Censored
        censored_prob = 1 - poisson_cdf(y[ind], mu=mu)
        numpyro.sample("censored_label", dist.Bernoulli(censored_prob).mask(censored_mask), obs=E[ind])

In practice, when you scale to much larger data you’ll probably want a bigger batch size (especially if you’re on GPU): basically whatever saturates your compute, although finding the “optimal” value may take some experimentation.
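One way to find that point is to time a fixed number of SVI steps at a few candidate batch sizes, as in the sketch below (make_model is a hypothetical factory that builds the model with the given subsample_size; the first run at each size also includes JIT compilation time):

import time
import numpyro
from jax import random

for batch_size in [5_000, 50_000, 500_000]:
    svi = numpyro.infer.SVI(
        make_model(batch_size),  # hypothetical: model built with this subsample_size
        hierarchical_guide,
        numpyro.optim.Adam(step_size=0.0005),
        loss=numpyro.infer.Trace_ELBO(),
    )
    start = time.time()
    svi.run(random.PRNGKey(0), 200, *inputs, progress_bar=False)
    print(batch_size, "->", round(time.time() - start, 1), "seconds for 200 steps")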


Hey, this method and example look great! Just to confirm: this samples with replacement, doesn’t it? If I want to run a standard mini-batch SVI with epochs and without-replacement sampling, is there any way besides using init(), update(), and evaluate() of the SVI class? @martinjankowiak

It is without replacement (within each mini-batch), which is standard practice in mini-batch SGD-like training.

Cool, that’s good to know. Does that also mean the plate with subsample_size will loop through an epoch, and that num_steps in SVI.run() is the number of epochs?

No, the subsampling is completely i.i.d. from step to step, so there’s no epoch structure. You’d need a custom SVI.run()-style loop if you want special mini-batching logic.
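For anyone who does want epoch-style mini-batching, here is a rough sketch using svi.init() and svi.update(). It assumes the model takes pre-sliced, row-level batch arrays and already rescales its likelihood to the full data size (e.g. via the plate’s size / subsample_size); run_epochs is just an illustrative name:

import jax
import numpy as np

def run_epochs(svi, rng_key, inputs, n_rows, batch_size, n_epochs):
    # inputs: tuple of row-level arrays, all of length n_rows
    state = svi.init(rng_key, *(x[:batch_size] for x in inputs))
    update_fn = jax.jit(svi.update)  # compile once; batch shapes stay fixed below
    losses = []
    for epoch in range(n_epochs):
        perm = np.random.permutation(n_rows)  # reshuffle every epoch
        # drop the ragged final batch so shapes stay static for the jitted update
        for start in range(0, n_rows - batch_size + 1, batch_size):
            idx = perm[start:start + batch_size]  # without replacement within an epoch
            state, loss = update_fn(state, *(x[idx] for x in inputs))
            losses.append(loss)
    return svi.get_params(state), losses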

Got it, that makes sense. Thanks for your help!