Iterative predictions

Hello,

I have a NumPyro model based on time series data where each prediction depends on the previous one (i.e. the prediction at t_i is an input feature at time t_{i+1}). To be more precise, we predict the number of sold goods per day, and the number of remaining goods is a feature.

Is there a way to speed up the computation of future data using JAX / NumPyro, compared to my current approach, which is basically a (slow) Python for loop?

I more or less do (in Python pseudo code):

f_predictive = Predictive(
    model,
    guide=guide,
    params=svi_result.params,
    num_samples=n_samples,
    return_sites=return_sites,
    parallel=parallel,
)

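N = 100  # forecast horizon (number of steps)
result = np.zeros((x.shape[0], N))  # one column of predictions per step
for i in range(N):
    # Posterior-predictive samples for the current features x
    posterior = f_predictive(
        rng_key,
        x=x,
        **kwargs,
    )
    # Average over posterior samples -> one prediction per row of x
    mean_obs = posterior["obs"].mean(axis=0)

    result[:, i] = mean_obs
    # Feed the prediction back in: the last column of x (a mutable NumPy
    # array here) holds the number of remaining goods
    x[:, -1] = max_goods - mean_obs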

Let me know if you need a proper MWE, but for now I just need guidance on which direction to look in with JAX / NumPyro.

Thanks a lot!

Can you please be more specific? Where is there an i dependence besides the result indexing?

I update the input x. The last column in x is the number of available goods at time t, and as time moves on, we sell more and more, decreasing this value. The issue is that the model is trained on the number of available goods, so if we want to predict the next 10 time steps, we have to:

  1. First compute the predicted number of sold goods at time t (the posterior)
  2. Subtract this number from the last column of x
  3. Compute the new number of sold goods at time t+1
  4. Subtract this number from x
  5. and so on

So I cannot compute the full input matrix (or dataframe) x and just compute the predictions at once, because they all depend on the previous predictions.

I am sorry I wasn’t clearer before; I hope this helps clear it up a bit.

But you only run SVI inference once?

This almost sounds like some kind of autoregressive process, which you can model with multivariate distributions, the CAR (conditional autoregressive) distribution, or an external package like tinygp. Could you write the model in mathematical form, even briefly?

Alternatively, the SVI and MCMC .run() methods take the model inputs as arguments, so you could call them recursively in that way.

Yeah, I just run / train the SVI once. My issue is solely with predictions, not training, since fitting the model on historical data is easy.

Thanks for your reply!

Yeah, so we have different products with certain features, e.g. product category, weight, and so on. These are “constant” in the sense that they do not change over time. Then we also have the number of sold products at time t as a feature, which obviously changes over time (and is the last column in the X described above). Historically, we know the price of each product and the amount sold.

We assume a specific demand function, e.g. an exponential demand: $D(P \mid D_0, \beta) = D_0 \exp(-\beta P)$, where $D$ is the demand (i.e. the number of sold products) and $P$ is the price.
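For reference, here is the demand function in JAX, plus one possible definition of the optimal_price_exponential helper that the model below calls but whose definition the post does not show. The helper is a hypothetical sketch that assumes the target is revenue $R(P) = P\,D(P)$ with zero marginal cost; then $dR/dP = D_0 e^{-\beta P}(1 - \beta P)$ vanishes at $P^* = 1/\beta$, which jax.grad can confirm numerically.

```python
import jax
import jax.numpy as jnp


def demand_exponential(D0, beta, price):
    # Exponential demand D(P | D0, beta) = D0 * exp(-beta * P)
    return D0 * jnp.exp(-beta * price)


def optimal_price_exponential(beta):
    # Hypothetical helper: revenue-maximizing price for exponential demand,
    # assuming zero marginal cost (dR/dP = 0 at P* = 1 / beta)
    return 1.0 / beta


def revenue(price, D0, beta):
    # Revenue R(P) = P * D(P)
    return price * demand_exponential(D0, beta, price)


D0, beta = 50.0, 0.2
p_star = optimal_price_exponential(beta)
slope = jax.grad(revenue)(p_star, D0, beta)  # derivative of revenue at P*, close to 0
```

If the actual objective includes costs or constraints, the closed form changes accordingly.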

The idea is then to learn the demand parameters $(D_0, \beta)$ from the data (X) using a Bayesian Neural Network (BNN). The model is roughly as follows:

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.module import random_flax_module

# NeuralNetwork, demand_exponential and optimal_price_exponential are defined elsewhere (not shown)

def exponential_demand_poisson(
    x,
    price=None,
    y=None,
    compute_optimal_price=True,
):
    neural_network = NeuralNetwork(output_dim=2, hidden_layers=2)
    bayesian_neural_network = random_flax_module(
        "nn",
        neural_network,
        prior=(lambda name, shape: dist.Cauchy() if name == "bias" else dist.Normal()),
        input_shape=x.shape,
    )

    N = x.shape[0]
    with numpyro.plate("observations", N):

        # The BNN maps each row of x to (log D0, log beta)
        log_D0, log_beta = bayesian_neural_network(x)
        D0 = numpyro.deterministic("D0", jnp.exp(log_D0))
        beta = numpyro.deterministic("beta", jnp.exp(log_beta))

        if price is not None:
            # Expected demand at the given price, observed through a Poisson likelihood
            D = demand_exponential(D0, beta, price)
            D = numpyro.deterministic("D", D + 0.01)  # constant to avoid zero demand
            numpyro.sample("demand", dist.Poisson(D), obs=y)

        if compute_optimal_price:
            numpyro.deterministic("optimal_price", optimal_price_exponential(beta))

We train the BNN with SVI on both the historical X and the prices.

Now, given a specific price (which we want to vary) at time t, we predict the posterior demand at time t+1: first $(D_0, \beta) = \mathrm{BNN}(X)$, then the demand function $D(P \mid D_0, \beta)$ gives the number of sold products at time t+1.

The issue is that we want to predict not just for time t+1 but, say, up to t+100. Currently I loop over the model predictions in a Python for loop, which predicts the demand and updates the last column of X at every iteration.
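Writing the loop out as a recursion may help make the dependency explicit. In this sketch, $K$ is the horizon (e.g. 100), the expectation is over the posterior and the Poisson noise (in practice the sample mean returned by Predictive), the small 0.01 offset the model adds to the rate is dropped, and $g$ is our placeholder for however the last column of $X$ is updated from the predicted sales (max_goods - mean_obs in the first pseudocode, mean_demand / 10 in the later example).

```latex
\begin{aligned}
\left(D_0^{(k)}, \beta^{(k)}\right) &= \mathrm{BNN}\left(X_{t+k-1}\right), \\
y_{t+k} &\sim \mathrm{Poisson}\!\left(D_0^{(k)} \exp\!\left(-\beta^{(k)} P_{t+k-1}\right)\right), \\
\hat{y}_{t+k} &= \mathbb{E}\left[\, y_{t+k} \mid X_{t+k-1}, P_{t+k-1} \right], \\
X_{t+k}[:, -1] &= g\left(\hat{y}_{t+k}\right), \qquad k = 1, \dots, K.
\end{aligned}
```

Each step needs $\hat{y}$ from the previous one, so the steps cannot be batched, but the body is the same function at every step, which is what makes it a candidate for jit or lax.scan.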

I hope this made the model a bit better described.

@HughMcDougallAstro, what do you mean by “the SVI triggers take the model inputs as an argument”?

Right now it takes around 1-2 seconds to make a single posterior prediction, which in itself is fine. However, with the naive Python for loop, a t+100 prediction takes a couple of minutes, which is a lot slower than our previous TensorFlow models.

I don’t think I really follow the data flow or motivation, but have you tried replacing the for loop with jax.lax.scan?
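A minimal sketch of what jax.lax.scan looks like for this kind of autoregressive rollout, assuming a toy deterministic stand-in predict_demand in place of the trained BNN (its weights, the shapes, and the / 10 feedback are illustrative only). The carry holds the feature matrix; each step predicts demand, writes it back into the last column with .at[].set() (JAX arrays are immutable), and emits the prediction.

```python
from functools import partial

import jax
import jax.numpy as jnp


def predict_demand(x, price):
    # Hypothetical stand-in for the BNN's posterior-mean prediction:
    # a fixed linear map to log D0, a constant beta, then exponential demand.
    w = jnp.array([0.1, 0.05, 0.02, 0.3])
    D0 = jnp.exp(x @ w)
    beta = 0.1
    return D0 * jnp.exp(-beta * price)


@partial(jax.jit, static_argnames="num_steps")
def rollout(x0, price, num_steps):
    def step(x, _):
        demand = predict_demand(x, price)
        # Feed the prediction back in as the next step's last feature column
        x_next = x.at[:, -1].set(demand / 10)
        return x_next, demand

    # scan carries x through num_steps iterations and stacks each demand
    _, demands = jax.lax.scan(step, x0, xs=None, length=num_steps)
    return demands.T  # (num_steps, N) -> (N, num_steps)


x0 = jnp.ones((5, 4))
price = jnp.linspace(1.0, 5.0, 5)
demands = rollout(x0, price, num_steps=20)
```

The whole horizon is traced and compiled once, instead of dispatching each step from Python. num_steps is static because scan needs a concrete length.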

Hm, how do you do that? I assume you have to jit-compile your function first, and even a simple example like the one below fails:

import jax

@jax.jit
def foo1(
    demand_model,
    guide,
    params,
    n_samples,
    return_sites,
    parallel,
    rng_key,
    x,
    price,
    **kwargs
):
    f_predictive = Predictive(
        demand_model,
        guide=guide,
        params=params,
        num_samples=n_samples,
        return_sites=return_sites,
        parallel=parallel,
    )

    posterior = f_predictive(
        rng_key,
        x=x,
        price=None,
        **kwargs,
    )

    return posterior

with the error:

TypeError: Cannot interpret value of type <class 'function'> as an abstract array; it does not have a dtype attribute

If I instead move the definition of f_predictive outside of the function (since this should also be outside of the jax.lax.scan loop):

f_predictive = Predictive(
    demand_model,
    guide=guide,
    params=params,
    num_samples=n_samples,
    return_sites=return_sites,
    parallel=parallel,
)

@jax.jit
def foo2(
    f_predictive,
    rng_key,
    x,
    price,
    **kwargs
):

    posterior = f_predictive(
        rng_key,
        x=x,
        price=None,
        **kwargs,
    )

    return posterior

foo2(f_predictive, rng_key, x, price.values, **kwargs)

I get a similar error:

TypeError: Cannot interpret value of type <class 'numpyro.infer.util.Predictive'> as an abstract array; it does not have a dtype attribute
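Both errors seem to have the same cause: jax.jit tries to turn every (non-static) argument into an abstract array, and neither a plain Python function nor a Predictive object has a dtype. Two common workarounds are to mark such an argument as static or to close over it instead of passing it in. A minimal pure-JAX sketch, where apply_twice and square are made-up names for illustration:

```python
import jax
import jax.numpy as jnp


def apply_twice(fn, x):
    # Takes a Python callable as an argument, like foo1/foo2 above
    return fn(fn(x))


def square(x):
    return x * x


x = jnp.arange(3.0)

# Passing the callable as a traced argument fails with a TypeError
try:
    jax.jit(apply_twice)(square, x)
    raised = False
except TypeError:
    raised = True

# Option 1: declare the callable argument static (it must be hashable)
apply_twice_jit = jax.jit(apply_twice, static_argnums=0)
out1 = apply_twice_jit(square, x)

# Option 2: close over the callable so only arrays are traced
square_twice_jit = jax.jit(lambda x: apply_twice(square, x))
out2 = square_twice_jit(x)
```

For a Predictive instance, closing over it (option 2) is usually the simpler route, since jit then only sees the key and the data arrays.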

Thanks a lot for helping, @martinjankowiak!

Okay, I have now tried to make an MWE so that the goal of this post is hopefully clearer.

First we import functions:

import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist
from flax import linen as nn
from jax.random import PRNGKey
from numpyro.contrib.module import random_flax_module
from numpyro.infer import (
    SVI,
    Predictive,
    TraceMeanField_ELBO,
    autoguide,
    init_to_feasible,
)
from optax import adam, exponential_decay

Then we define the demand function (demand_exponential) and the full Bayesian model (exponential_demand_poisson) using the Flax neural network:

def demand_exponential(D0, beta, price):
    return D0 * jnp.exp(-beta * price)

class NeuralNetwork(nn.Module):
    output_dim: int
    hidden_layers: list

    @nn.compact
    def __call__(self, x):
        for n_units in self.hidden_layers:
            x = nn.Dense(n_units)(x)
            x = nn.relu(x)
        out = nn.Dense(self.output_dim)(x)
        return out.squeeze().T  # allows for unpacking for both 1 and multiple outputs

def exponential_demand_poisson(
    x,
    price=None,
    y=None,
):
    neural_network = NeuralNetwork(output_dim=2, hidden_layers=[32, 16])
    bayesian_neural_network = random_flax_module(
        "nn",
        neural_network,
        prior=(lambda name, shape: dist.Cauchy() if name == "bias" else dist.Normal()),
        input_shape=x.shape,
    )

    N = x.shape[0]
    with numpyro.plate("observations", N):
        log_D0, log_beta = bayesian_neural_network(x)
        D0 = numpyro.deterministic("D0", jnp.exp(log_D0))
        beta = numpyro.deterministic("beta", jnp.exp(log_beta))

        if price is not None:
            D = demand_exponential(D0, beta, price)
            D = numpyro.deterministic("D", D + 0.01)  # constant to avoid zero demand
            numpyro.sample("demand", dist.Poisson(D), obs=y)

Now we generate some toy data:

def generate_data(N):
    np.random.seed(42)
    x = np.hstack(
        [
            np.random.normal(loc=5, size=(N, 3)),
            np.random.randint(0, 10, size=(N, 1)),
        ]
    )
    price = np.random.uniform(0, 10, size=N)

    d0 = x[:, 0] ** 2 + x[:, 1] + 10 * x[:, -1]
    beta = 0.1
    d = demand_exponential(d0, beta, price)
    y = np.random.poisson(d)

    return x, y, price

x, y, price = generate_data(N=1000)

and fit it using SVI:

rng_key = PRNGKey(0)
n_fit_iterations = 10_000
n_samples = 1000

guide = autoguide.AutoNormal(
    exponential_demand_poisson,
    init_loc_fn=init_to_feasible,
)

learning_rate = 0.01
decay_rate = 0.01
optimizer = adam(
    exponential_decay(
        learning_rate,
        n_fit_iterations,
        decay_rate,
    )
)

svi = SVI(
    exponential_demand_poisson,
    guide,
    optimizer,
    TraceMeanField_ELBO(),
)

svi_result = svi.run(
    rng_key,
    n_fit_iterations,
    progress_bar=True,
    x=x,
    price=price,
    y=y,
)

Now we arrive at the goal. I want to make iterative_predict faster, hopefully by using JAX or NumPyro functionality:

f_predictive = Predictive(
    exponential_demand_poisson,
    guide=guide,
    params=svi_result.params,
    num_samples=n_samples,
    return_sites=["D0", "beta", "demand"],
    parallel=False,
)

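def iterative_predict(x, M=10):
    demands = np.zeros((len(x), M))
    x = x.copy()  # work on a copy so the caller's x is not modified

    for i in range(M):
        # Un-jitted Predictive call, repeated M times
        posterior = f_predictive(
            rng_key,
            x=x,
            price=price,
            y=None,
        )

        # Average over posterior samples -> one demand per row of x
        mean_demand = posterior["demand"].mean(axis=0)
        # Feed the prediction back into the last feature column
        x[:, -1] = mean_demand / 10
        demands[:, i] = mean_demand

    return demands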

So the idea is that we use the trained Bayesian neural network in f_predictive and iteratively predict on the data x, where x itself changes according to the output of the previous step.

I hope this makes sense. Btw, thanks a lot for the help so far, @martinjankowiak and @HughMcDougallAstro!

posterior = f_predictive(
    rng_key,
    x=x,
    price=price,
    y=None,
)

This function can be jitted to make the iteration faster. You can use lax.scan, but a plain for loop is probably enough.


That was indeed the trick, @fehiepsi, thanks a lot! Now we can run it for 100 iterations in less than a second, so that’s perfect.