Condition, without using scan

If I run this code:

```python
import jax
import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist
import numpyro.infer as infer
from numpyro.contrib.control_flow import scan

x = jnp.asarray(np.linspace(-3, 3, 10) * 1.0)
y = jnp.asarray(2 * x + 1 + np.random.randn(len(x)) * 0.01)


def model(x, y, future=0):
    T = len(y)
    alpha = numpyro.sample("alpha", dist.Normal(0, 3))
    beta = numpyro.sample("beta", dist.Normal(0, 3))
    sigma = numpyro.sample("sigma", dist.HalfNormal(1))

    def transition_fn(carry, t):
        mu = alpha + beta * x[t]
        pred = numpyro.sample("pred", dist.Normal(mu, sigma))
        return None, pred

    with numpyro.handlers.condition(data={"pred": y}):
        _, preds = scan(transition_fn, None, jnp.arange(T + future))
    return preds


mcmc = infer.MCMC(
    infer.NUTS(model),
    num_chains=1,
    num_samples=500,
    num_warmup=500,
)
mcmc.run(jax.random.PRNGKey(0), x=x, y=y[:-2])

pred = infer.Predictive(model, mcmc.get_samples(), return_sites=["pred"])
preds = pred(jax.random.PRNGKey(0), x=x, y=y[:-2], future=2)

preds["pred"].shape
```

then I get (500, 10).

If I try to rewrite this without using scan, i.e.:

```python
def model(x, y, future=0):
    T = len(y)
    alpha = numpyro.sample("alpha", dist.Normal(0, 3))
    beta = numpyro.sample("beta", dist.Normal(0, 3))
    sigma = numpyro.sample("sigma", dist.HalfNormal(1))

    with numpyro.handlers.condition(data={"pred": y}):
        mu = alpha + beta * x[: T + future]
        preds = numpyro.sample("pred", dist.Normal(mu, sigma))
    return preds
```

then I get (500, 8).

Why is that?

(I’m aware that I don’t strictly need condition in this example, since just using obs= would be enough; I’m only trying to make sense of condition for a time-series problem.)

Using the condition handler with scan for forecasting is a special feature, and it can be confusing. How about writing two loops, one for the likelihood and one for forecasting? That way, your scan and non-scan versions of the model will give consistent results.

As for why condition with scan can forecast: we support that special case because time-series problems are common, and we think that having a single loop for both the likelihood and forecasting is convenient. When the conditioned array is shorter than the number of scan steps, scan uses the data for the first len(y) steps and samples the remaining ones. That is why your scan version returns all T + future = 10 values per posterior sample.

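For models that do not use scan, condition has no such special behavior: it is equivalent to setting obs=... on that site. The sample statement then simply returns y, so your preds has the same value as y and the same length (8), regardless of future. That is where (500, 8) comes from.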
