If I run this code:

```python
import jax
import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist
import numpyro.infer as infer
from numpyro.contrib.control_flow import scan
# Ten evenly spaced covariate values on [-3, 3]; the `1.0 *` factor leaves
# linspace's float output unchanged.
x = jnp.asarray(1.0 * np.linspace(-3, 3, 10))
# Responses along the line 1 + 2x plus small Gaussian noise. The noise is drawn
# from NumPy's global RNG, which is not seeded here, so the values (but not the
# shapes) change from run to run.
y = jnp.asarray(1 + 2 * x + 0.01 * np.random.randn(len(x)))
def model(x, y, future=0):
    """Linear regression on ``x`` written as a step-by-step ``scan``.

    Args:
        x: covariate values, indexed by time step ``t``.
        y: observed responses; its length ``T`` is the number of observed steps.
        future: number of extra steps to run past the end of ``y``.

    Returns:
        The stacked per-step ``"pred"`` values for all ``T + future`` steps.
    """
    T = len(y)
    alpha = numpyro.sample("alpha", dist.Normal(0, 3))
    beta = numpyro.sample("beta", dist.Normal(0, 3))
    sigma = numpyro.sample("sigma", dist.HalfNormal(1))
    def transition_fn(carry, t):
        # One time step. No state is carried between steps (the carry is always
        # None); the per-step "pred" sample is emitted as the scan output.
        mu = alpha + beta * x[t]
        pred = numpyro.sample("pred", dist.Normal(mu, sigma))
        return None, pred
    # `condition` supplies the series `y` (length T) for the "pred" site, while
    # scan runs for T + future steps.
    # NOTE(review): with future=2 this yields T + future values per posterior
    # draw (the (500, 10) result reported in the text). That suggests scan uses
    # y[t] for t < T and draws fresh "pred" samples for the remaining `future`
    # steps. Confirm against the numpyro.contrib.control_flow.scan docs.
    with numpyro.handlers.condition(data={"pred": y}):
        _, preds = scan(transition_fn, None, jnp.arange(T + future))
    return preds
# Fit on all but the last two observations. `future` keeps its default of 0
# here, so the model runs for exactly len(y[:-2]) == 8 steps during inference.
mcmc = infer.MCMC(
    infer.NUTS(model),
    num_chains=1,
    num_samples=500,
    num_warmup=500,
)
mcmc.run(jax.random.PRNGKey(0), x=x, y=y[:-2])

# Re-run the model once per posterior draw (500 draws, from num_samples above),
# asking for 2 extra steps past the observed data and keeping only "pred".
pred = infer.Predictive(model, mcmc.get_samples(), return_sites=["pred"])
preds = pred(jax.random.PRNGKey(0), x=x, y=y[:-2], future=2)

# A bare `preds["pred"].shape` is only displayed in a REPL/notebook. Print it
# so the shape is also shown when this runs as a plain script.
print(preds["pred"].shape)
```

then I get `(500, 10)`.

If I try to rewrite this without using `scan`, i.e.:

```python
def model(x, y, future=0):
    """Same linear regression, but with a single vectorized ``"pred"`` site.

    Args:
        x: covariate values; the first ``T + future`` entries are used.
        y: observed responses; its length ``T`` is the number of observed steps.
        future: number of extra steps requested past the end of ``y``.

    Returns:
        The value of the ``"pred"`` sample site.
    """
    T = len(y)
    alpha = numpyro.sample("alpha", dist.Normal(0, 3))
    beta = numpyro.sample("beta", dist.Normal(0, 3))
    sigma = numpyro.sample("sigma", dist.HalfNormal(1))
    with numpyro.handlers.condition(data={"pred": y}):
        # mu has T + future entries (given len(x) >= T + future).
        mu = alpha + beta * x[: T + future]
        # NOTE(review): "pred" is conditioned on the whole of `y` (length T), so
        # the site's value is `y` itself, not a draw with mu's T + future
        # entries. This is consistent with the (500, 8) result reported in the
        # text, and nothing here generates the `future` steps (presumably that
        # would need a separate, unobserved site). With future > 0, mu and y
        # also have mismatched lengths, so evaluating this site's log-density
        # would presumably fail to broadcast. Verify before relying on it.
        preds = numpyro.sample("pred", dist.Normal(mu, sigma))
    return preds
```

then I get `(500, 8)`.

Why is that?

(I'm aware that I don't strictly need `condition` in this example — just using `obs=` would be enough. I'm just trying to make sense of `condition` for a time-series problem.)