Hi! I’m relatively new to numpyro and pyro.
I have been looking at the time series, gaussian process and hierarchical gaussian process examples in pyro and numpyro, but am a bit confused as to how they differ when using gaussian processes for time series prediction. I understand it a in numpyro, we do not have gaussian processes or kernels in the framework itself, nor a dist.GaussianProcess.
What I’m trying to do is to modify this method from lightweighMMM
def media_mix_model(
media_data: jnp.ndarray,
target_data: jnp.ndarray,
media_prior: jnp.ndarray,
degrees_seasonality: int,
frequency: int,
transform_function: TransformFunction,
custom_priors: MutableMapping[str, Prior],
transform_kwargs: Optional[MutableMapping[str, Any]] = None,
weekday_seasonality: bool = False,
extra_features: Optional[jnp.ndarray] = None,
) -> None:
"""Media mix model.
Args:
media_data: Media data to be be used in the model.
target_data: Target data for the model.
media_prior: Cost prior for each of the media channels.
degrees_seasonality: Number of degrees of seasonality to use.
frequency: Frequency of the time span which was used to aggregate the data.
Eg. if weekly data then frequency is 52.
transform_function: Function to use to transform the media data in the
model. Currently the following are supported: 'transform_adstock',
'transform_carryover' and 'transform_hill_adstock'.
custom_priors: The custom priors we want the model to take instead of the
default ones. See our custom_priors documentation for details about the
API and possible options.
transform_kwargs: Any extra keyword arguments to pass to the transform
function. For example the adstock function can take a boolean to noramlise
output or not.
weekday_seasonality: In case of daily data you can estimate a weekday (7)
parameter.
extra_features: Extra features data to include in the model.
"""
default_priors = _get_default_priors()
data_size = media_data.shape[0]
n_channels = media_data.shape[1]
geo_shape = (media_data.shape[2],) if media_data.ndim == 3 else ()
n_geos = media_data.shape[2] if media_data.ndim == 3 else 1
with numpyro.plate(name=f"{_INTERCEPT}_plate", size=n_geos):
intercept = numpyro.sample(
name=_INTERCEPT,
fn=custom_priors.get(_INTERCEPT, default_priors[_INTERCEPT]))
with numpyro.plate(name=f"{_SIGMA}_plate", size=n_geos):
sigma = numpyro.sample(
name=_SIGMA,
fn=custom_priors.get(_SIGMA, default_priors[_SIGMA]))
# TODO(): Force all geos to have the same trend sign.
with numpyro.plate(name=f"{_COEF_TREND}_plate", size=n_geos):
coef_trend = numpyro.sample(
name=_COEF_TREND,
fn=custom_priors.get(_COEF_TREND, default_priors[_COEF_TREND]))
expo_trend = numpyro.sample(
name=_EXPO_TREND,
fn=custom_priors.get(
_EXPO_TREND, default_priors[_EXPO_TREND]))
with numpyro.plate(
name="channel_media_plate",
size=n_channels,
dim=-2 if media_data.ndim == 3 else -1):
coef_media = numpyro.sample(
name="channel_coef_media" if media_data.ndim == 3 else "coef_media",
fn=dist.HalfNormal(scale=media_prior))
if media_data.ndim == 3:
with numpyro.plate(
name="geo_media_plate",
size=n_geos,
dim=-1):
# Corrects the mean to be the same as in the channel only case.
normalisation_factor = jnp.sqrt(2.0 / jnp.pi)
coef_media = numpyro.sample(
name="coef_media",
fn=dist.HalfNormal(scale=coef_media * normalisation_factor)
)
with numpyro.plate(name=f"{_GAMMA_SEASONALITY}_sin_cos_plate", size=2):
with numpyro.plate(name=f"{_GAMMA_SEASONALITY}_plate",
size=degrees_seasonality):
gamma_seasonality = numpyro.sample(
name=_GAMMA_SEASONALITY,
fn=custom_priors.get(
_GAMMA_SEASONALITY, default_priors[_GAMMA_SEASONALITY]))
if weekday_seasonality:
with numpyro.plate(name=f"{_WEEKDAY}_plate", size=7):
weekday = numpyro.sample(
name=_WEEKDAY,
fn=custom_priors.get(_WEEKDAY, default_priors[_WEEKDAY]))
weekday_series = weekday[jnp.arange(data_size) % 7]
# In case of daily data, number of lags should be 13*7.
if transform_function == "carryover" and transform_kwargs and "number_lags" not in transform_kwargs:
transform_kwargs["number_lags"] = 13 * 7
elif transform_function == "carryover" and not transform_kwargs:
transform_kwargs = {"number_lags": 13 * 7}
media_transformed = numpyro.deterministic(
name="media_transformed",
value=transform_function(media_data,
custom_priors=custom_priors,
**transform_kwargs if transform_kwargs else {}))
seasonality = media_transforms.calculate_seasonality(
number_periods=data_size,
degrees=degrees_seasonality,
frequency=frequency,
gamma_seasonality=gamma_seasonality)
# For national model's case
trend = jnp.arange(data_size)
media_einsum = "tc, c -> t" # t = time, c = channel
coef_seasonality = 1
# TODO(): Add conversion of prior for HalfNormal distribution.
if media_data.ndim == 3: # For geo model's case
trend = jnp.expand_dims(trend, axis=-1)
seasonality = jnp.expand_dims(seasonality, axis=-1)
media_einsum = "tcg, cg -> tg" # t = time, c = channel, g = geo
if weekday_seasonality:
weekday_series = jnp.expand_dims(weekday_series, axis=-1)
with numpyro.plate(name="seasonality_plate", size=n_geos):
coef_seasonality = numpyro.sample(
name=_COEF_SEASONALITY,
fn=custom_priors.get(
_COEF_SEASONALITY, default_priors[_COEF_SEASONALITY]))
# expo_trend is B(1, 1) so that the exponent on time is in [.5, 1.5].
prediction = (
intercept + coef_trend * trend ** expo_trend +
seasonality * coef_seasonality +
jnp.einsum(media_einsum, media_transformed, coef_media))
if extra_features is not None:
plate_prefixes = ("extra_feature",)
extra_features_einsum = "tf, f -> t" # t = time, f = feature
extra_features_plates_shape = (extra_features.shape[1],)
if extra_features.ndim == 3:
plate_prefixes = ("extra_feature", "geo")
extra_features_einsum = "tfg, fg -> tg" # t = time, f = feature, g = geo
extra_features_plates_shape = (extra_features.shape[1], *geo_shape)
with numpyro.plate_stack(plate_prefixes,
sizes=extra_features_plates_shape):
coef_extra_features = numpyro.sample(
name=_COEF_EXTRA_FEATURES,
fn=custom_priors.get(
_COEF_EXTRA_FEATURES, default_priors[_COEF_EXTRA_FEATURES]))
extra_features_effect = jnp.einsum(extra_features_einsum,
extra_features,
coef_extra_features)
prediction += extra_features_effect
if weekday_seasonality:
prediction += weekday_series
mu = numpyro.deterministic(name="mu", value=prediction)
numpyro.sample(
name="target", fn=dist.Normal(loc=mu, scale=sigma), obs=target_data)
I would like to incorporate that the ROI (return on investment) for having more marketing spend on a particular channel depends on the seasonality and the media channel
- Bayesian Media Mix Models: Modelling changes in marketing effectiveness over time - PyMC Labs
- You’re probably modeling seasonality the wrong way - Recast
as this example code.
def media_mix_seasonal_model(
dates, # array of dates or time periods
sales, # array of sales data
media_spend, # array of media spend data
seasonality, # array of seasonality data, e.g., indicator variables for seasons, holidays, etc.
):
# Priors for base demand and marketing effect
base_demand = numpyro.sample("base_demand", dist.Normal(0, 1))
marketing_effect = numpyro.sample("marketing_effect", dist.Normal(0, 1))
# Use a Gaussian Process to model dynamic ROI influenced by seasonality
# Kernel parameters for the GP
gp_length_scale = numpyro.sample("gp_length_scale", dist.LogNormal(0., 1.))
gp_output_scale = numpyro.sample("gp_output_scale", dist.LogNormal(0., 1.))
# Covariance function for the GP
# We assume that the ROI is a smooth function of time and has yearly seasonality
kernel = gp_output_scale * dist.kernels.ExpQuad(gp_length_scale) * dist.kernels.ExpSinSquared(period=365.25)
gp = numpyro.sample("gp", dist.GaussianProcess(kernel, index_points=dates.reshape(-1, 1)))
# Expected sales are a combination of base demand, marketing spend, and a seasonally-varying ROI
expected_sales = base_demand + jnp.exp(gp) * media_spend * seasonality
# Likelihood of observing the sales data
numpyro.sample("obs", dist.Normal(expected_sales, 0.1), obs=sales)
However, every time I try to incorporate a gaussian process prior in my model, it is either FAR to slow to train OR the gaussian process prior will depend on the input shape (training) and not work when running the model for prediction.
Here is one of many attempt to do that:
def _media_mix_model_time_varying_kernel(
media_data: jnp.ndarray,
target_data: jnp.ndarray,
media_prior: jnp.ndarray,
degrees_seasonality: int,
frequency: int,
transform_function: TransformFunction,
custom_priors: MutableMapping[str, Prior],
transform_kwargs: Optional[MutableMapping[str, Any]] = None,
weekday_seasonality: bool = False,
extra_features: Optional[jnp.ndarray] = None,
time_varying=True,
) -> None:
"""Media mix model.
Args:
media_data: Media data to be be used in the model.
target_data: Target data for the model.
media_prior: Cost prior for each of the media channels.
degrees_seasonality: Number of degrees of seasonality to use.
frequency: Frequency of the time span which was used to aggregate the data.
Eg. if weekly data then frequency is 52.
transform_function: Function to use to transform the media data in the
model. Currently the following are supported: 'transform_adstock',
'transform_carryover' and 'transform_hill_adstock'.
custom_priors: The custom priors we want the model to take instead of the
default ones. See our custom_priors documentation for details about the
API and possible options.
transform_kwargs: Any extra keyword arguments to pass to the transform
function. For example the adstock function can take a boolean to noramlise
output or not.
weekday_seasonality: In case of daily data you can estimate a weekday (7)
parameter.
extra_features: Extra features data to include in the model.
"""
default_priors = _get_default_priors()
data_size = media_data.shape[0]
n_channels = media_data.shape[1]
geo_shape = (media_data.shape[2],) if media_data.ndim == 3 else ()
n_geos = media_data.shape[2] if media_data.ndim == 3 else 1
with numpyro.plate(name=f"{_INTERCEPT}_plate", size=n_geos):
intercept = numpyro.sample(
name=_INTERCEPT,
fn=custom_priors.get(_INTERCEPT, default_priors[_INTERCEPT]))
with numpyro.plate(name=f"{_SIGMA}_plate", size=n_geos):
sigma = numpyro.sample(
name=_SIGMA,
fn=custom_priors.get(_SIGMA, default_priors[_SIGMA]))
# TODO(): Force all geos to have the same trend sign.
with numpyro.plate(name=f"{_COEF_TREND}_plate", size=n_geos):
coef_trend = numpyro.sample(
name=_COEF_TREND,
fn=custom_priors.get(_COEF_TREND, default_priors[_COEF_TREND]))
expo_trend = numpyro.sample(
name=_EXPO_TREND,
fn=custom_priors.get(
_EXPO_TREND, default_priors[_EXPO_TREND]))
with numpyro.plate(name=f"{_GAMMA_SEASONALITY}_sin_cos_plate", size=2):
with numpyro.plate(name=f"{_GAMMA_SEASONALITY}_plate",
size=degrees_seasonality):
gamma_seasonality = numpyro.sample(
name=_GAMMA_SEASONALITY,
fn=custom_priors.get(
_GAMMA_SEASONALITY, default_priors[_GAMMA_SEASONALITY]))
if weekday_seasonality:
with numpyro.plate(name=f"{_WEEKDAY}_plate", size=7):
weekday = numpyro.sample(
name=_WEEKDAY,
fn=custom_priors.get(_WEEKDAY, default_priors[_WEEKDAY]))
weekday_series = weekday[jnp.arange(data_size) % 7]
# In case of daily data, number of lags should be 13*7.
if transform_function == "carryover" and transform_kwargs and "number_lags" not in transform_kwargs:
transform_kwargs["number_lags"] = 13 * 7
elif transform_function == "carryover" and not transform_kwargs:
transform_kwargs = {"number_lags": 13 * 7}
media_transformed = numpyro.deterministic(
name="media_transformed",
value=transform_function(media_data,
custom_priors=custom_priors,
**transform_kwargs if transform_kwargs else {}))
seasonality = media_transforms.calculate_seasonality(
number_periods=data_size,
degrees=degrees_seasonality,
frequency=frequency,
gamma_seasonality=gamma_seasonality)
# For national model's case
trend = jnp.arange(data_size)
media_einsum = "tc, c -> t" # t = time, c = channel
coef_seasonality = 1
if media_data.ndim == 3: # For geo model's case
trend = jnp.expand_dims(trend, axis=-1)
seasonality = jnp.expand_dims(seasonality, axis=-1)
media_einsum = "tcg, cg -> tg" # t = time, c = channel, g = geo
if weekday_seasonality:
weekday_series = jnp.expand_dims(weekday_series, axis=-1)
with numpyro.plate(name="seasonality_plate", size=n_geos):
coef_seasonality = numpyro.sample(
name=_COEF_SEASONALITY,
fn=custom_priors.get(
_COEF_SEASONALITY, default_priors[_COEF_SEASONALITY]))
with numpyro.plate("channel_plate", n_channels):
length_scales = numpyro.sample("length_scales", dist.LogNormal(0, 1))
rbf_length_scales = numpyro.sample("rbf_length_scales", dist.LogNormal(0, 1))
periodic_length_scales = numpyro.sample("periodic_length_scales", dist.LogNormal(0, 1))
# 0. Basic basis kernel - Does not need the periodic and combined kernels
# # Create the kernel matrix for each channel and stack them together
# Ks = [rbf_kernel(jnp.arange(data_size)[:, None], jnp.arange(data_size)[:, None], l) for l in length_scales]
# K = jnp.stack(Ks, axis=0) # Shape: (n_channels, data_size, data_size)
# K += jnp.eye(data_size) * 1e-6 # for numerical stability
# # ----------------------------------------------------------
# 1. Full Gaussian process
# Create the kernel matrix for each channel using combined kernel and stack them together
# Ks = [combined_kernel(jnp.arange(data_size)[:, None], jnp.arange(data_size)[:, None], periodic_kernel, r, p) for r, p in zip(rbf_length_scales, periodic_length_scales)]
# K = jnp.stack(Ks, axis=0)
# K += jnp.eye(data_size) * 1e-6 # for numerical stability
# # ----------------------------------------------------------
# # 2. Sparse
# # Number of inducing points
# num_inducing = 3 # or any number much smaller than data_size # Select inducing points (you can choose these more wisely, this is just an example)
# inducing_points = jnp.sort(jax.random.choice(jax.random.PRNGKey(0), jnp.arange(data_size), shape=(num_inducing,), replace=False))
# Ks = [
# sparse_gaussian_process_kernel(jnp.arange(data_size), inducing_points, yearly_periodic_kernel, r, p)
# for r, p in zip(rbf_length_scales, periodic_length_scales)
# ]
# K = jnp.stack(Ks, axis=0)
# K += jnp.eye(data_size) * 1e-6 # for numerical stability
K = get_kernel_matrix("sparse", data_size, rbf_length_scales, periodic_length_scales, period=frequency)
# Step 2: Use the GP covariance matrix to sample the coefficients for each marketing channel over time
with numpyro.plate("geo_media_plate", n_geos):
coef_media = numpyro.sample(
"coef_media",
dist.MultivariateNormal(loc=jnp.zeros(data_size), covariance_matrix=K).to_event(1)
)
# Sum across channels
print(f"{media_einsum=}, {media_data.shape=}, {coef_media.shape=}")
media_spend_temp = media_data[:, :, :, None] * coef_media.transpose(3, 1, 0, 2)
media_spend = jnp.sum(media_spend_temp, axis=1)
media_spend = jnp.sum(media_spend, axis=2)
# expo_trend is B(1, 1) so that the exponent on time is in [.5, 1.5].
prediction = (
intercept + coef_trend * trend ** expo_trend +
seasonality * coef_seasonality +
media_spend
)
if extra_features is not None:
plate_prefixes = ("extra_feature",)
extra_features_einsum = "tf, f -> t" # t = time, f = feature
extra_features_plates_shape = (extra_features.shape[1],)
if extra_features.ndim == 3:
plate_prefixes = ("extra_feature", "geo")
extra_features_einsum = "tfg, fg -> tg" # t = time, f = feature, g = geo
extra_features_plates_shape = (extra_features.shape[1], *geo_shape)
with numpyro.plate_stack(plate_prefixes,
sizes=extra_features_plates_shape):
coef_extra_features = numpyro.sample(
name=_COEF_EXTRA_FEATURES,
fn=custom_priors.get(
_COEF_EXTRA_FEATURES, default_priors[_COEF_EXTRA_FEATURES]))
extra_features_effect = jnp.einsum(extra_features_einsum,
extra_features,
coef_extra_features)
prediction += extra_features_effect
if weekday_seasonality:
prediction += weekday_series
mu = numpyro.deterministic(name="mu", value=prediction)
numpyro.sample(
name="target", fn=dist.Normal(loc=mu, scale=sigma), obs=target_data)