I have a hierarchical Poisson GLM in NumPyro that I’d like to scale from 15 million rows of data (~15 minutes to fit locally) to 15 billion rows.
Are there any examples of how to scale models like this in production, e.g. using a GPU, parallelizing, or subsampling? I’m having trouble finding any. I’d also be willing to use Pyro if it’s easier, but my model currently seems more performant in NumPyro.
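For the GPU piece specifically, the only switch I’m aware of is selecting the backend before any JAX work happens (this assumes a CUDA-enabled jaxlib install); I’m guessing real scaling needs batching on top of this:

import numpyro

# Pick the accelerator before any JAX computation runs; with a CUDA
# build of jaxlib, all the SVI work below then executes on the GPU.
numpyro.set_platform("gpu")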
Here’s an example of how I’m calling my model and guide functions:
import numpy as np
import jax.numpy as jnp
from jax import random
from jax.scipy import special as jsp

import numpyro
import numpyro.distributions as dist
from numpyro.infer import Trace_ELBO, Predictive

def prep_data(data):
    # Collapse to one row per SKU for the SKU-level predictors
    sku_level = data.groupby("sku").first()[
        ["product_id", "category"]
        + data.filter(like="size_").columns.tolist()
    ]
    results = [
        sku_level.product_id.cat.codes.values,
        data.sku.cat.codes.values,
        sku_level.category.cat.codes.values,
        data.seasonality.cat.codes.values,
        data.month.values - 1,  # months as 0-11
        data.dow.values,
        data.hour.values,
        sku_level.filter(like="size_").iloc[:, 1:].values,  # drop reference size
        data.time_since_published.values,
        data.rentals_obs.values,
        data.observed.values,
    ]
    return results
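(Worth noting: prep_data assumes sku, product_id, category, and seasonality are already pandas Categoricals, since it relies on .cat.codes. A made-up frame, just to illustrate the dtypes it expects:)

import pandas as pd

# Values are made up; only the dtypes matter. The size_* columns are
# one-hot indicators, and prep_data drops the first as the reference level.
toy = pd.DataFrame({
    "sku": pd.Categorical(["s1", "s1", "s2"]),
    "product_id": pd.Categorical(["p1", "p1", "p2"]),
    "category": pd.Categorical(["tops", "tops", "pants"]),
    "seasonality": pd.Categorical(["low", "mid", "high"]),
    "month": [1, 2, 3],
    "dow": [0, 1, 2],
    "hour": [10, 11, 12],
    "size_s": [1, 1, 0],
    "size_m": [0, 0, 1],
    "time_since_published": [5.0, 6.0, 7.0],
    "rentals_obs": [0, 1, 2],
    "observed": [1, 1, 0],
})
prep_data(toy)  # returns the list of arrays in the order the model expects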
optimizer = numpyro.optim.Adam(step_size=0.0005)
svi = numpyro.infer.SVI(
    hierarchical_poisson_censored_regression,
    hierarchical_guide,
    optimizer,
    loss=Trace_ELBO(),
)
inputs = prep_data(data)
svi_result = svi.run(random.PRNGKey(0), 2000, *inputs)
params = svi_result.params

# Get posterior samples from the fitted guide
predictive = Predictive(hierarchical_guide, params=params, num_samples=1000)
samples = predictive(random.PRNGKey(1), *inputs)
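One direction I’ve been experimenting with is replacing the single svi.run call with a manual svi.update loop over fixed-size minibatches, so the full dataset never has to sit on the device at once. This is a sketch, not something I have working end to end: get_batch is a hypothetical shard loader, I haven’t jitted the update (the model below computes plate sizes with np.unique, which wouldn’t trace), and without rescaling the likelihood the ELBO is biased toward each minibatch:

from jax import random

rng_key = random.PRNGKey(0)
n_batches = 1000                              # hypothetical number of shards
svi_state = svi.init(rng_key, *get_batch(0))  # get_batch is a stand-in loader

for step in range(2000):
    batch = get_batch(step % n_batches)       # fixed-size minibatch of inputs
    svi_state, loss = svi.update(svi_state, *batch)

params = svi.get_params(svi_state)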
And here’s the model:
def poisson_logcdf(value, mu):
    # Poisson log CDF via the regularized upper incomplete gamma:
    # P(Y <= k; mu) = Q(k + 1, mu) = gammaincc(k + 1, mu)
    safe_mu = jnp.where(jnp.less(mu, 0), 0, mu)
    safe_value = jnp.where(jnp.less(value, 0), 0, value)
    return jnp.where(
        jnp.less(value, 0),
        -jnp.inf,
        jnp.log(jsp.gammaincc(safe_value + 1, safe_mu)),
    )

def poisson_cdf(value, mu):
    return jnp.exp(poisson_logcdf(value, mu))
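(Quick sanity check of the gammaincc identity against SciPy, with arbitrary numbers; both lines print the same CDF value:)

from scipy import stats

# Both should print ~0.7576 for k = 3, mu = 2.5.
print(float(poisson_cdf(3.0, 2.5)))
print(stats.poisson.cdf(3, 2.5))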
def hierarchical_poisson_censored_regression(
    product_j,
    sku_k,
    category_code,
    season_code,
    month_code,
    dow_code,
    hour_code,
    size_matrix,
    days_since_publish,
    y,
    E,
):
    # Shapes
    n_products = len(np.unique(product_j))
    n_skus = len(np.unique(sku_k))
    n_categories = len(np.unique(category_code))
    n_sizes = size_matrix.shape[1]

    # Masks: E == 1 means the rental count was fully observed
    censored_mask = E == 0
    observed_mask = E == 1

    # Product-level coefficients
    mu_alpha_product = numpyro.sample("mu_alpha_product", dist.Normal(-4, 1))
    sigma_alpha_product = numpyro.sample("sigma_alpha_product", dist.Exponential(2.5))
    with numpyro.plate("products", n_products):
        alpha_product = numpyro.sample(
            "alpha_product", dist.Normal(mu_alpha_product, sigma_alpha_product)
        )
    with numpyro.plate("categories", n_categories):
        b_category = numpyro.sample("b_category", dist.Normal(0, 0.5))
    with numpyro.plate("sizes", n_sizes):
        b_size = numpyro.sample("b_size", dist.Normal(0, 0.5))

    # SKU-level coefficients
    b_decay = numpyro.sample("b_decay", dist.Normal(-0.4, 0.1))
    with numpyro.plate("dayofweek", 7):
        b_dow = numpyro.sample("b_dow", dist.Normal(0, 0.5))
    with numpyro.plate("hours", 24):
        b_hour = numpyro.sample("b_hour", dist.Normal(0, 0.5))
    with numpyro.plate("months", 12):
        b_month = numpyro.sample("b_month", dist.Normal(0, 0.5))
    with numpyro.plate("seasonalities", 3):
        b_season = numpyro.sample("b_season", dist.Normal(0, 0.5))

    # Top level of the hierarchy: products > SKUs
    mu_alpha = alpha_product[product_j] + b_category[category_code] + b_size @ size_matrix.T
    sigma_alpha = numpyro.sample("sigma_alpha", dist.Exponential(2.5))
    with numpyro.plate("skus", n_skus):
        alpha = numpyro.sample("alpha", dist.Normal(mu_alpha, sigma_alpha))

    # Regression with SKU random effects
    log_lambd = numpyro.deterministic(
        "log_lambd",
        alpha[sku_k]
        + b_month[month_code] * b_season[season_code]
        + b_dow[dow_code]
        + b_hour[hour_code]
        + b_decay * days_since_publish / 365,
    )
    mu = jnp.exp(log_lambd)

    with numpyro.plate("data", days_since_publish.shape[0]):
        # Fully observed rows contribute the Poisson likelihood
        numpyro.sample("obs", dist.Poisson(mu).mask(observed_mask), obs=y)
        # Censored rows: Bernoulli likelihood on the censoring indicator,
        # where censored_prob = P(Y > y) under the Poisson model
        censored_prob = 1 - poisson_cdf(y, mu=mu)
        numpyro.sample("censored_label", dist.Bernoulli(censored_prob).mask(censored_mask), obs=E)