Okay, I have now put together an MWE (minimal working example) so that the goal of this post is hopefully clearer.
First, we import the required modules and functions:
import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist
from flax import linen as nn
from jax.random import PRNGKey
from numpyro.contrib.module import random_flax_module
from numpyro.infer import (
    SVI,
    Predictive,
    TraceMeanField_ELBO,
    autoguide,
    init_to_feasible,
)
from optax import adam, exponential_decay
Then we define the demand function (demand_exponential) and the full Bayesian model (exponential_demand_poisson) using the Flax neural network:
def demand_exponential(D0, beta, price):
    return D0 * jnp.exp(-beta * price)


class NeuralNetwork(nn.Module):
    output_dim: int
    hidden_layers: list

    @nn.compact
    def __call__(self, x):
        for n_units in self.hidden_layers:
            x = nn.Dense(n_units)(x)
            x = nn.relu(x)
        out = nn.Dense(self.output_dim)(x)
        return out.squeeze().T  # allows unpacking for both one and multiple outputs

def exponential_demand_poisson(
    x,
    price=None,
    y=None,
):
    neural_network = NeuralNetwork(output_dim=2, hidden_layers=[32, 16])
    bayesian_neural_network = random_flax_module(
        "nn",
        neural_network,
        # Parameter names of nested Flax modules are qualified (e.g.
        # "Dense_0.bias"), so match on the suffix rather than name == "bias".
        prior=(
            lambda name, shape: dist.Cauchy()
            if name.endswith("bias")
            else dist.Normal()
        ),
        input_shape=x.shape,
    )
    N = x.shape[0]
    with numpyro.plate("observations", N):
        log_D0, log_beta = bayesian_neural_network(x)
        D0 = numpyro.deterministic("D0", jnp.exp(log_D0))
        beta = numpyro.deterministic("beta", jnp.exp(log_beta))
        if price is not None:
            D = demand_exponential(D0, beta, price)
            D = numpyro.deterministic("D", D + 0.01)  # small constant to avoid zero demand
            numpyro.sample("demand", dist.Poisson(D), obs=y)
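As a quick sanity check (just an illustration with made-up inputs, not part of the pipeline), the model can be traced once under a seed handler to confirm that it runs and that the site shapes come out as expected:

from numpyro import handlers

x_test = jnp.ones((5, 4))
price_test = jnp.ones(5)
with handlers.seed(rng_seed=0):
    model_trace = handlers.trace(exponential_demand_poisson).get_trace(
        x=x_test, price=price_test, y=None
    )
print(model_trace["D0"]["value"].shape)  # (5,)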
Now we generate some toy data:
def generate_data(N):
    np.random.seed(42)
    x = np.hstack(
        [
            np.random.normal(loc=5, size=(N, 3)),
            np.random.randint(0, 10, size=(N, 1)),
        ]
    )
    price = np.random.uniform(0, 10, size=N)
    d0 = x[:, 0] ** 2 + x[:, 1] + 10 * x[:, -1]
    beta = 0.1
    d = demand_exponential(d0, beta, price)
    y = np.random.poisson(d)
    return x, y, price


x, y, price = generate_data(N=1000)
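For reference, the resulting arrays have these shapes:

print(x.shape, y.shape, price.shape)  # (1000, 4) (1000,) (1000,)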
and fit it using SVI:
rng_key = PRNGKey(0)
n_fit_iterations = 10_000
n_samples = 1000

guide = autoguide.AutoNormal(
    exponential_demand_poisson,
    init_loc_fn=init_to_feasible,
)

learning_rate = 0.01
decay_rate = 0.01
optimizer = adam(
    exponential_decay(
        learning_rate,
        n_fit_iterations,
        decay_rate,
    )
)

svi = SVI(
    exponential_demand_poisson,
    guide,
    optimizer,
    TraceMeanField_ELBO(),
)
svi_result = svi.run(
    rng_key,
    n_fit_iterations,
    progress_bar=True,
    x=x,
    price=price,
    y=y,
)
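For completeness, convergence can be checked by plotting the per-iteration ELBO losses that svi.run stores on the result object (requires matplotlib):

import matplotlib.pyplot as plt

plt.plot(svi_result.losses)
plt.xlabel("iteration")
plt.ylabel("ELBO loss")
plt.show()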
Now we arrive at the goal. I want to make iterative_predict faster, hopefully by utilizing JAX or any NumPyro-related functionality:
f_predictive = Predictive(
    exponential_demand_poisson,
    guide=guide,
    params=svi_result.params,
    num_samples=n_samples,
    return_sites=["D0", "beta", "demand"],
    parallel=False,
)


def iterative_predict(x, M=10):
    demands = np.zeros((len(x), M))
    x = x.copy()
    for i in range(M):
        posterior = f_predictive(
            rng_key,
            x=x,
            price=price,
            y=None,
        )
        mean_demand = posterior["demand"].mean(axis=0)
        x[:, -1] = mean_demand / 10  # feed the mean demand back into the last feature
        demands[:, i] = mean_demand
    return demands
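Calling it gives one column of mean demands per feedback iteration:

demands = iterative_predict(x, M=10)
print(demands.shape)  # (1000, 10)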
So the idea is that we use the trained Bayesian neural network in f_predictive and iteratively predict on the data x, where x itself changes according to the output of the previous step.
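To be concrete, I imagine replacing the Python loop with something along these lines; this is an untested sketch, and I don't know whether a Predictive call traces cleanly inside jax.lax.scan:

import jax


def _step(carry, _):
    # One feedback iteration: predict, then write the mean demand back
    # into the last feature column.
    x, key = carry
    key, subkey = jax.random.split(key)
    posterior = f_predictive(subkey, x=x, price=price, y=None)
    mean_demand = posterior["demand"].mean(axis=0)
    x = x.at[:, -1].set(mean_demand / 10)
    return (x, key), mean_demand


def iterative_predict_scan(x, M=10):
    # Untested: relies on f_predictive being traceable under lax.scan.
    (_, _), demands = jax.lax.scan(_step, (jnp.asarray(x), rng_key), None, length=M)
    return demands.T  # (len(x), M), same layout as iterative_predict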
I hope this makes sense. Btw, thanks a lot for the help so far, @martinjankowiak and @HughMcDougallAstro !