If I run this code
import jax
import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist
import numpyro.infer as infer
from numpyro.contrib.control_flow import scan
# Synthetic data: 10 evenly spaced inputs on [-3, 3] and a noisy line y = 2x + 1.
# A fixed-seed generator makes the noise (and therefore the MCMC/Predictive
# results below) reproducible from run to run.
x = jnp.asarray(np.linspace(-3, 3, 10))
y = jnp.asarray(2 * x + 1 + np.random.default_rng(0).standard_normal(len(x)) * 0.01)
def model(x, y, future=0):
    """Linear-regression model whose "pred" site is generated one step at a time via ``scan``.

    ``y`` holds the observed values; the scan runs over ``len(y) + future``
    time indices, and ``x`` is indexed at every one of them, so it should
    have at least that many entries.

    Returns the stacked ``"pred"`` values emitted by the scan, one per step.
    """
    n_obs = len(y)
    alpha = numpyro.sample("alpha", dist.Normal(0, 3))
    beta = numpyro.sample("beta", dist.Normal(0, 3))
    sigma = numpyro.sample("sigma", dist.HalfNormal(1))

    def step(_carry, t):
        # One time step: a single "pred" draw around the regression line at x[t].
        # Nothing needs to be carried between steps, hence the None carry.
        pred = numpyro.sample("pred", dist.Normal(alpha + beta * x[t], sigma))
        return None, pred

    # NOTE(review): "pred" is conditioned on n_obs values while the scan runs for
    # n_obs + future steps. The post reports this returns all n_obs + future values
    # (shape (500, 10) for n_obs=8, future=2), presumably because scan observes the
    # first n_obs steps and samples the remaining ones. Confirm against numpyro's
    # scan documentation.
    with numpyro.handlers.condition(data={"pred": y}):
        _, preds = scan(step, None, jnp.arange(n_obs + future))
    return preds
# Fit on all but the last two observations, then forecast two steps past them.
# Separate keys for inference and prediction: JAX PRNG keys should not be reused
# for unrelated random computations.
rng_key_mcmc, rng_key_pred = jax.random.split(jax.random.PRNGKey(0))
mcmc = infer.MCMC(
    infer.NUTS(model),
    num_chains=1,
    num_samples=500,
    num_warmup=500,
)
mcmc.run(rng_key_mcmc, x=x, y=y[:-2])
pred = infer.Predictive(model, mcmc.get_samples(), return_sites=["pred"])
preds = pred(rng_key_pred, x=x, y=y[:-2], future=2)
# A bare expression is only echoed in a REPL/notebook; print so a script shows it too.
print(preds["pred"].shape)
then I get `(500, 10)`.
If I try to rewrite this without using `scan`, i.e.:
def model(x, y, future=0):
    """Vectorised version of the same linear-regression model, without ``scan``.

    A single ``"pred"`` sample statement covers the first ``len(y) + future``
    entries of ``x``, and that whole site is conditioned on ``y``.

    Returns the value of the ``"pred"`` site.
    """
    horizon = len(y) + future
    alpha = numpyro.sample("alpha", dist.Normal(0, 3))
    beta = numpyro.sample("beta", dist.Normal(0, 3))
    sigma = numpyro.sample("sigma", dist.HalfNormal(1))
    with numpyro.handlers.condition(data={"pred": y}):
        # NOTE(review): the distribution has `horizon` entries, but conditioning
        # presumably replaces the site's value with `y` wholesale. That would
        # explain the reported (500, 8) shape: `future` does not lengthen the
        # returned array. Verify against numpyro's `condition` docs.
        preds = numpyro.sample("pred", dist.Normal(alpha + beta * x[:horizon], sigma))
    return preds
then I get `(500, 8)`.
Why is that?
(I’m aware that I don’t strictly need `condition` in this example — just using `obs=` would be enough — I’m just trying to make sense of `condition` for a time-series problem.)