Hierarchical Gaussian process in NumPyro for a marketing mix model?

Hi! I’m relatively new to numpyro and pyro.

I have been looking at the time-series, Gaussian process and hierarchical Gaussian process examples in Pyro and NumPyro, but I am a bit confused about how they differ when Gaussian processes are used for time-series prediction. As I understand it, NumPyro does not provide Gaussian processes or kernels in the framework itself, nor a `dist.GaussianProcess` distribution.

What I’m trying to do is modify this function from LightweightMMM:


def media_mix_model(
    media_data: jnp.ndarray,
    target_data: jnp.ndarray,
    media_prior: jnp.ndarray,
    degrees_seasonality: int,
    frequency: int,
    transform_function: TransformFunction,
    custom_priors: MutableMapping[str, Prior],
    transform_kwargs: Optional[MutableMapping[str, Any]] = None,
    weekday_seasonality: bool = False,
    extra_features: Optional[jnp.ndarray] = None,
) -> None:
  """Media mix model.

  NumPyro model function: it registers sample/deterministic sites and returns
  nothing. The expected target ``mu`` is

    intercept + coef_trend * t ** expo_trend
    + seasonality * coef_seasonality
    + sum_c media_transformed[t, c] * coef_media[c]
    (+ extra-feature and weekday effects when enabled),

  and ``target_data`` is observed through a Normal likelihood with scale
  ``sigma``.

  Args:
    media_data: Media data to be used in the model, indexed as
      (time, channel) for a national model or (time, channel, geo) for a geo
      model.
    target_data: Target data for the model; observed at the ``"target"`` site.
    media_prior: Cost prior for each of the media channels; used as the scale
      of the HalfNormal prior on the media coefficients.
    degrees_seasonality: Number of degrees of seasonality to use.
    frequency: Frequency of the time span which was used to aggregate the data.
      Eg. if weekly data then frequency is 52.
    transform_function: Function to use to transform the media data in the
      model. Currently the following are supported: 'transform_adstock',
        'transform_carryover' and 'transform_hill_adstock'.
    custom_priors: The custom priors we want the model to take instead of the
      default ones. See our custom_priors documentation for details about the
      API and possible options.
    transform_kwargs: Any extra keyword arguments to pass to the transform
      function. For example the adstock function can take a boolean to
      normalise output or not.
    weekday_seasonality: In case of daily data you can estimate a weekday (7)
      parameter.
    extra_features: Extra features data to include in the model, indexed as
      (time, feature) or (time, feature, geo).
  """
  default_priors = _get_default_priors()
  data_size = media_data.shape[0]
  n_channels = media_data.shape[1]
  geo_shape = (media_data.shape[2],) if media_data.ndim == 3 else ()
  n_geos = media_data.shape[2] if media_data.ndim == 3 else 1

  # One intercept / noise scale / trend coefficient per geo (a single one for
  # the national model, where n_geos == 1).
  with numpyro.plate(name=f"{_INTERCEPT}_plate", size=n_geos):
    intercept = numpyro.sample(
        name=_INTERCEPT,
        fn=custom_priors.get(_INTERCEPT, default_priors[_INTERCEPT]))

  with numpyro.plate(name=f"{_SIGMA}_plate", size=n_geos):
    sigma = numpyro.sample(
        name=_SIGMA,
        fn=custom_priors.get(_SIGMA, default_priors[_SIGMA]))

  # TODO(): Force all geos to have the same trend sign.
  with numpyro.plate(name=f"{_COEF_TREND}_plate", size=n_geos):
    coef_trend = numpyro.sample(
        name=_COEF_TREND,
        fn=custom_priors.get(_COEF_TREND, default_priors[_COEF_TREND]))

  # A single trend exponent shared by all geos (no plate).
  expo_trend = numpyro.sample(
      name=_EXPO_TREND,
      fn=custom_priors.get(
          _EXPO_TREND, default_priors[_EXPO_TREND]))

  # Channel plate on dim -2 in the geo model so that the nested geo plate can
  # own dim -1; coef_media ends up indexed (channel,) or (channel, geo).
  with numpyro.plate(
      name="channel_media_plate",
      size=n_channels,
      dim=-2 if media_data.ndim == 3 else -1):
    coef_media = numpyro.sample(
        name="channel_coef_media" if media_data.ndim == 3 else "coef_media",
        fn=dist.HalfNormal(scale=media_prior))
    if media_data.ndim == 3:
      with numpyro.plate(
          name="geo_media_plate",
          size=n_geos,
          dim=-1):
        # Corrects the mean to be the same as in the channel only case.
        normalisation_factor = jnp.sqrt(2.0 / jnp.pi)
        coef_media = numpyro.sample(
            name="coef_media",
            fn=dist.HalfNormal(scale=coef_media * normalisation_factor)
        )

  # Fourier seasonality weights: one (sin, cos) pair per degree.
  with numpyro.plate(name=f"{_GAMMA_SEASONALITY}_sin_cos_plate", size=2):
    with numpyro.plate(name=f"{_GAMMA_SEASONALITY}_plate",
                       size=degrees_seasonality):
      gamma_seasonality = numpyro.sample(
          name=_GAMMA_SEASONALITY,
          fn=custom_priors.get(
              _GAMMA_SEASONALITY, default_priors[_GAMMA_SEASONALITY]))

  if weekday_seasonality:
    with numpyro.plate(name=f"{_WEEKDAY}_plate", size=7):
      weekday = numpyro.sample(
          name=_WEEKDAY,
          fn=custom_priors.get(_WEEKDAY, default_priors[_WEEKDAY]))
    # Repeats the 7 weekday effects cyclically along the time axis.
    # NOTE(review): assumes row 0 of the data is the first modelled weekday.
    weekday_series = weekday[jnp.arange(data_size) % 7]
    # In case of daily data, number of lags should be 13*7.
    # NOTE(review): transform_function is called as a function below, so
    # comparing it with the string "carryover" looks like it can never be
    # true. The first branch would also mutate the caller's transform_kwargs
    # dict in place.
    if transform_function == "carryover" and transform_kwargs and "number_lags" not in transform_kwargs:
      transform_kwargs["number_lags"] = 13 * 7
    elif transform_function == "carryover" and not transform_kwargs:
      transform_kwargs = {"number_lags": 13 * 7}

  # Recorded as a deterministic site so the transformed media is available in
  # the posterior samples.
  media_transformed = numpyro.deterministic(
      name="media_transformed",
      value=transform_function(media_data,
                               custom_priors=custom_priors,
                               **transform_kwargs if transform_kwargs else {}))
  seasonality = media_transforms.calculate_seasonality(
      number_periods=data_size,
      degrees=degrees_seasonality,
      frequency=frequency,
      gamma_seasonality=gamma_seasonality)
  # For national model's case
  trend = jnp.arange(data_size)
  media_einsum = "tc, c -> t"  # t = time, c = channel
  coef_seasonality = 1

  # TODO(): Add conversion of prior for HalfNormal distribution.
  if media_data.ndim == 3:  # For geo model's case
    # Add a trailing geo axis so the time series broadcast against the
    # per-geo parameters.
    trend = jnp.expand_dims(trend, axis=-1)
    seasonality = jnp.expand_dims(seasonality, axis=-1)
    media_einsum = "tcg, cg -> tg"  # t = time, c = channel, g = geo
    if weekday_seasonality:
      weekday_series = jnp.expand_dims(weekday_series, axis=-1)
    with numpyro.plate(name="seasonality_plate", size=n_geos):
      coef_seasonality = numpyro.sample(
          name=_COEF_SEASONALITY,
          fn=custom_priors.get(
              _COEF_SEASONALITY, default_priors[_COEF_SEASONALITY]))
  # expo_trend is B(1, 1) so that the exponent on time is in [.5, 1.5].
  prediction = (
      intercept + coef_trend * trend ** expo_trend +
      seasonality * coef_seasonality +
      jnp.einsum(media_einsum, media_transformed, coef_media))
  if extra_features is not None:
    plate_prefixes = ("extra_feature",)
    extra_features_einsum = "tf, f -> t"  # t = time, f = feature
    extra_features_plates_shape = (extra_features.shape[1],)
    if extra_features.ndim == 3:
      plate_prefixes = ("extra_feature", "geo")
      extra_features_einsum = "tfg, fg -> tg"  # t = time, f = feature, g = geo
      extra_features_plates_shape = (extra_features.shape[1], *geo_shape)
    with numpyro.plate_stack(plate_prefixes,
                             sizes=extra_features_plates_shape):
      coef_extra_features = numpyro.sample(
          name=_COEF_EXTRA_FEATURES,
          fn=custom_priors.get(
              _COEF_EXTRA_FEATURES, default_priors[_COEF_EXTRA_FEATURES]))
    extra_features_effect = jnp.einsum(extra_features_einsum,
                                       extra_features,
                                       coef_extra_features)
    prediction += extra_features_effect

  if weekday_seasonality:
    prediction += weekday_series
  mu = numpyro.deterministic(name="mu", value=prediction)

  numpyro.sample(
      name="target", fn=dist.Normal(loc=mu, scale=sigma), obs=target_data)

I would like the model to capture that the ROI (return on investment) of additional marketing spend on a particular channel depends on both the seasonality and the media channel, along the lines of this sketch:

def media_mix_seasonal_model(
    dates,  # array of dates or time periods
    sales,  # array of sales data
    media_spend,  # array of media spend data
    seasonality,  # array of seasonality data, e.g., indicator variables for seasons, holidays, etc.
):
    # Priors for base demand and marketing effect
    base_demand = numpyro.sample("base_demand", dist.Normal(0, 1))
    marketing_effect = numpyro.sample("marketing_effect", dist.Normal(0, 1))

    # Use a Gaussian Process to model dynamic ROI influenced by seasonality
    # Kernel parameters for the GP
    gp_length_scale = numpyro.sample("gp_length_scale", dist.LogNormal(0., 1.))
    gp_output_scale = numpyro.sample("gp_output_scale", dist.LogNormal(0., 1.))

    # Covariance function for the GP
    # We assume that the ROI is a smooth function of time and has yearly seasonality
    kernel = gp_output_scale * dist.kernels.ExpQuad(gp_length_scale) * dist.kernels.ExpSinSquared(period=365.25)
    gp = numpyro.sample("gp", dist.GaussianProcess(kernel, index_points=dates.reshape(-1, 1)))

    # Expected sales are a combination of base demand, marketing spend, and a seasonally-varying ROI
    expected_sales = base_demand + jnp.exp(gp) * media_spend * seasonality

    # Likelihood of observing the sales data
    numpyro.sample("obs", dist.Normal(expected_sales, 0.1), obs=sales)

However, every time I try to incorporate a Gaussian process prior into my model, either training is far too slow, or the Gaussian process prior depends on the shape of the training inputs and does not work when I run the model for prediction.

Here is one of my many attempts:

def _sample_static_media_coefs(
    media_prior: jnp.ndarray,
    n_channels: int,
    n_geos: int,
    is_geo_model: bool,
) -> jnp.ndarray:
  """Samples time-constant, non-negative media coefficients.

  Uses the same priors and site names as the coefficient block of
  ``media_mix_model``. Each channel gets a HalfNormal with scale
  ``media_prior``. In the geo model, each geo additionally gets a HalfNormal
  whose scale is the channel-level draw.

  Returns:
    Coefficients indexed (channel,) for the national model or
    (channel, geo) for the geo model.
  """
  with numpyro.plate(
      name="channel_media_plate",
      size=n_channels,
      dim=-2 if is_geo_model else -1):
    coef_media = numpyro.sample(
        name="channel_coef_media" if is_geo_model else "coef_media",
        fn=dist.HalfNormal(scale=media_prior))
    if is_geo_model:
      with numpyro.plate(name="geo_media_plate", size=n_geos, dim=-1):
        # Corrects the mean to be the same as in the channel only case.
        normalisation_factor = jnp.sqrt(2.0 / jnp.pi)
        coef_media = numpyro.sample(
            name="coef_media",
            fn=dist.HalfNormal(scale=coef_media * normalisation_factor))
  return coef_media


def _sample_time_varying_media_coefs(
    data_size: int,
    n_channels: int,
    n_geos: int,
    frequency: int,
    is_geo_model: bool,
) -> jnp.ndarray:
  """Samples one Gaussian-process coefficient path per channel (and geo).

  Kernel hyperparameters are sampled once per channel. The kernel matrix comes
  from ``get_kernel_matrix`` (defined elsewhere in the project). Coefficients
  are drawn from a zero-mean ``MultivariateNormal`` whose event dimension is
  time.

  Args:
    data_size: Number of time steps T.
    n_channels: Number of media channels C.
    n_geos: Number of geos G (only used when ``is_geo_model``).
    frequency: Period, in time steps, passed on to the kernel builder.
    is_geo_model: Whether the media data carries a trailing geo axis.

  Returns:
    Coefficients indexed (channel, time) for the national model or
    (channel, geo, time) for the geo model.
  """
  with numpyro.plate("channel_plate", n_channels):
    rbf_length_scales = numpyro.sample(
        "rbf_length_scales", dist.LogNormal(0, 1))
    periodic_length_scales = numpyro.sample(
        "periodic_length_scales", dist.LogNormal(0, 1))
    kernel = get_kernel_matrix("sparse", data_size, rbf_length_scales,
                               periodic_length_scales, period=frequency)

  # NOTE(review): assumes get_kernel_matrix returns either a single (T, T)
  # matrix or one per channel, (C, T, T). broadcast_to accepts both and fails
  # loudly on any other shape.
  kernel = jnp.broadcast_to(kernel, (n_channels, data_size, data_size))
  loc = jnp.zeros(data_size)

  if not is_geo_model:
    # Channels on the only batch dim; time is the event dim -> (C, T).
    with numpyro.plate("channel_media_plate", n_channels, dim=-1):
      return numpyro.sample(
          "coef_media",
          dist.MultivariateNormal(loc=loc, covariance_matrix=kernel))

  # Geo model: channels on dim -2, geos on dim -1, time is the event dim
  # -> (C, G, T). The kernel is shared across geos, hence the singleton axis.
  with numpyro.plate("channel_media_plate", n_channels, dim=-2):
    with numpyro.plate("geo_media_plate", n_geos, dim=-1):
      return numpyro.sample(
          "coef_media",
          dist.MultivariateNormal(
              loc=loc, covariance_matrix=kernel[:, None, :, :]))


def _media_mix_model_time_varying_kernel(
    media_data: jnp.ndarray,
    target_data: jnp.ndarray,
    media_prior: jnp.ndarray,
    degrees_seasonality: int,
    frequency: int,
    transform_function: TransformFunction,
    custom_priors: MutableMapping[str, Prior],
    transform_kwargs: Optional[MutableMapping[str, Any]] = None,
    weekday_seasonality: bool = False,
    extra_features: Optional[jnp.ndarray] = None,
    time_varying=True,
) -> None:
  """Media mix model with optionally time-varying media coefficients.

  Same structure as the static media mix model except for the media term.

  With ``time_varying=True``, every channel (and geo) gets one coefficient per
  time step, drawn from a Gaussian process. The media effect is then
  ``sum_c media_transformed[t, c] * coef_media[c, t]``.

  With ``time_varying=False``, the coefficients are the time-constant
  HalfNormal ones scaled by ``media_prior``.

  NOTE(review): the GP is zero-mean, so time-varying coefficients can be
  negative. Its ``coef_media`` site also has a time axis of length
  ``data_size``, so forecasting beyond the training window needs the GP
  conditional rather than a plain re-run of this model.

  Args:
    media_data: Media data to be used in the model, indexed as
      (time, channel) or (time, channel, geo).
    target_data: Target data for the model.
    media_prior: Cost prior for each of the media channels (only used when
      ``time_varying`` is False).
    degrees_seasonality: Number of degrees of seasonality to use.
    frequency: Frequency of the time span which was used to aggregate the data.
      Eg. if weekly data then frequency is 52. Also used as the kernel period.
    transform_function: Function to use to transform the media data in the
      model. Currently the following are supported: 'transform_adstock',
        'transform_carryover' and 'transform_hill_adstock'.
    custom_priors: The custom priors we want the model to take instead of the
      default ones. See our custom_priors documentation for details about the
      API and possible options.
    transform_kwargs: Any extra keyword arguments to pass to the transform
      function. For example the adstock function can take a boolean to
      normalise output or not.
    weekday_seasonality: In case of daily data you can estimate a weekday (7)
      parameter.
    extra_features: Extra features data to include in the model.
    time_varying: If True (default), media coefficients follow a Gaussian
      process over time; otherwise they are constant over time.
  """
  default_priors = _get_default_priors()
  data_size = media_data.shape[0]
  n_channels = media_data.shape[1]
  is_geo_model = media_data.ndim == 3
  geo_shape = (media_data.shape[2],) if is_geo_model else ()
  n_geos = media_data.shape[2] if is_geo_model else 1

  with numpyro.plate(name=f"{_INTERCEPT}_plate", size=n_geos):
    intercept = numpyro.sample(
        name=_INTERCEPT,
        fn=custom_priors.get(_INTERCEPT, default_priors[_INTERCEPT]))

  with numpyro.plate(name=f"{_SIGMA}_plate", size=n_geos):
    sigma = numpyro.sample(
        name=_SIGMA,
        fn=custom_priors.get(_SIGMA, default_priors[_SIGMA]))

  # TODO(): Force all geos to have the same trend sign.
  with numpyro.plate(name=f"{_COEF_TREND}_plate", size=n_geos):
    coef_trend = numpyro.sample(
        name=_COEF_TREND,
        fn=custom_priors.get(_COEF_TREND, default_priors[_COEF_TREND]))

  expo_trend = numpyro.sample(
      name=_EXPO_TREND,
      fn=custom_priors.get(
          _EXPO_TREND, default_priors[_EXPO_TREND]))

  with numpyro.plate(name=f"{_GAMMA_SEASONALITY}_sin_cos_plate", size=2):
    with numpyro.plate(name=f"{_GAMMA_SEASONALITY}_plate",
                       size=degrees_seasonality):
      gamma_seasonality = numpyro.sample(
          name=_GAMMA_SEASONALITY,
          fn=custom_priors.get(
              _GAMMA_SEASONALITY, default_priors[_GAMMA_SEASONALITY]))

  if weekday_seasonality:
    with numpyro.plate(name=f"{_WEEKDAY}_plate", size=7):
      weekday = numpyro.sample(
          name=_WEEKDAY,
          fn=custom_priors.get(_WEEKDAY, default_priors[_WEEKDAY]))
    weekday_series = weekday[jnp.arange(data_size) % 7]
    # In case of daily data, number of lags should be 13*7.
    # NOTE(review): transform_function is a callable, so these comparisons
    # with the string "carryover" look like they can never be true; kept for
    # parity with the static media mix model.
    if transform_function == "carryover" and transform_kwargs and "number_lags" not in transform_kwargs:
      transform_kwargs["number_lags"] = 13 * 7
    elif transform_function == "carryover" and not transform_kwargs:
      transform_kwargs = {"number_lags": 13 * 7}

  media_transformed = numpyro.deterministic(
      name="media_transformed",
      value=transform_function(media_data,
                               custom_priors=custom_priors,
                               **transform_kwargs if transform_kwargs else {}))

  seasonality = media_transforms.calculate_seasonality(
      number_periods=data_size,
      degrees=degrees_seasonality,
      frequency=frequency,
      gamma_seasonality=gamma_seasonality)
  # For national model's case
  trend = jnp.arange(data_size)
  coef_seasonality = 1

  if is_geo_model:  # For geo model's case
    trend = jnp.expand_dims(trend, axis=-1)
    seasonality = jnp.expand_dims(seasonality, axis=-1)
    if weekday_seasonality:
      weekday_series = jnp.expand_dims(weekday_series, axis=-1)
    with numpyro.plate(name="seasonality_plate", size=n_geos):
      coef_seasonality = numpyro.sample(
          name=_COEF_SEASONALITY,
          fn=custom_priors.get(
              _COEF_SEASONALITY, default_priors[_COEF_SEASONALITY]))

  # t = time, c = channel, g = geo.
  if time_varying:
    coef_media = _sample_time_varying_media_coefs(
        data_size=data_size,
        n_channels=n_channels,
        n_geos=n_geos,
        frequency=frequency,
        is_geo_model=is_geo_model)
    media_einsum = "tcg, cgt -> tg" if is_geo_model else "tc, ct -> t"
  else:
    coef_media = _sample_static_media_coefs(
        media_prior=media_prior,
        n_channels=n_channels,
        n_geos=n_geos,
        is_geo_model=is_geo_model)
    media_einsum = "tcg, cg -> tg" if is_geo_model else "tc, c -> t"
  # Multiply the transformed (not raw) media so that the transform's
  # parameters actually enter the likelihood.
  media_effect = jnp.einsum(media_einsum, media_transformed, coef_media)

  # expo_trend is B(1, 1) so that the exponent on time is in [.5, 1.5].
  prediction = (
      intercept + coef_trend * trend ** expo_trend +
      seasonality * coef_seasonality +
      media_effect
  )

  if extra_features is not None:
    plate_prefixes = ("extra_feature",)
    extra_features_einsum = "tf, f -> t"  # t = time, f = feature
    extra_features_plates_shape = (extra_features.shape[1],)
    if extra_features.ndim == 3:
      plate_prefixes = ("extra_feature", "geo")
      extra_features_einsum = "tfg, fg -> tg"  # t = time, f = feature, g = geo
      extra_features_plates_shape = (extra_features.shape[1], *geo_shape)
    with numpyro.plate_stack(plate_prefixes,
                             sizes=extra_features_plates_shape):
      coef_extra_features = numpyro.sample(
          name=_COEF_EXTRA_FEATURES,
          fn=custom_priors.get(
              _COEF_EXTRA_FEATURES, default_priors[_COEF_EXTRA_FEATURES]))
    extra_features_effect = jnp.einsum(extra_features_einsum,
                                       extra_features,
                                       coef_extra_features)
    prediction += extra_features_effect

  if weekday_seasonality:
    prediction += weekday_series
  mu = numpyro.deterministic(name="mu", value=prediction)

  numpyro.sample(
      name="target", fn=dist.Normal(loc=mu, scale=sigma), obs=target_data)

Hi @markussagen,

Thanks for the question. I think you’re more likely to get a useful answer if you ask a targeted question that doesn’t involve hundreds of lines of code, i.e. a minimal working example.