Scan along a plate?

Is there a way to use `scan` + `condition` along a specific plate? I have a hierarchical random-walk model, and it seems I have to transpose the array so that time is along the first dimension for `condition` + `scan` to work.

Is there any way around this? I'd love to have a pattern like the following (I know you can't have `scan` inside a plate, but I think it illustrates what I'm looking for):

# Illustrative sketch only: `time_plate` stands for a plate over the time axis
# and `kwargs` for the usual scan arguments (f, init, xs). The wish is to scan
# along the axis a plate declares, instead of transposing the data so that
# time becomes the leading axis.
with time_plate:
    _, ys = scan(**kwargs)

Here's the model:

def hierarchical_rw_model(y=None, future=0):
    """Hierarchical Gaussian random walk with per-series drift and trend.

    Each series ``i`` evolves as::

        y[i, t] ~ Normal(y[i, t-1] + alpha[i] + beta[i] * t / 365.25, sigma[i])

    where ``alpha``, ``beta`` and ``sigma`` are per-series draws from shared
    (global) priors.

    Args:
        y: Observations laid out as ``(N, T)``: series along axis 0, time along
            axis 1. ``y.shape[0]`` sizes the ``time_series`` plate and
            ``y[:, 0]`` seeds the walk. ``None`` skips conditioning; note that
            ``N`` and ``T`` are then 0, so the plates are empty.
        future: Number of extra steps to simulate past the end of ``y``. Those
            steps are not covered by the conditioned data, and are exposed as
            the ``y_forecast`` deterministic site, shape ``(future, N)``
            (time-major, like everything ``scan`` stacks).

    Notes:
        ``scan`` steps over the *leading* axis of its inputs, and the
        conditioned ``y`` site is consumed one leading-axis slice per step.
        The observations are therefore handed over time-major
        (``y[:, 1:].T``, shape ``(T - 1, N)``). The series axis is declared
        with a plate *inside* the scanned function, which is how a plate and
        ``scan`` are combined.
    """
    N = 0 if y is None else y.shape[0]
    T = 0 if y is None else y.shape[1]
    level_init = 0 if y is None else y[:, 0]

    # Global (population-level) priors.
    mu_alpha = numpyro.sample("mu_alpha", dist.Normal(0, 1))
    sig_alpha = numpyro.sample("sig_alpha", dist.Exponential(2.5))
    mu_beta = numpyro.sample("mu_beta", dist.Normal(0, 1))
    sig_beta = numpyro.sample("sig_beta", dist.Exponential(1))
    mu_sigma = numpyro.sample("mu_sigma", dist.Normal(0, 1))

    # Per-series parameters: one draw per series.
    with numpyro.plate("time_series", N):
        alpha = numpyro.sample("alpha", dist.Normal(mu_alpha, sig_alpha))
        beta = numpyro.sample("beta", dist.Normal(mu_beta, sig_beta))
        sigma = numpyro.sample("sigma", dist.HalfNormal(jnp.exp(mu_sigma)))

    def transition_fn(y_prev, t):
        # Advance all N series by one step; `t` is the scalar time index from xs.
        exp_val = y_prev + alpha + beta * t / 365.25
        # The time axis belongs to scan; the independent series axis is
        # declared here, inside the scanned body, as a plate.
        with numpyro.plate("time_series", N, dim=-1):
            y_t = numpyro.sample("y", dist.Normal(exp_val, sigma))
        # The new level is both the next carry and the per-step output.
        return y_t, y_t

    # Time must be the leading axis of the conditioned data, so the (N, T)
    # observations become (T - 1, N). Without data there is nothing to
    # condition on (and y[:, 1:] would fail on None).
    data = {} if y is None else {"y": y[:, 1:].T}
    with numpyro.handlers.condition(data=data):
        _, ys = scan(
            f=transition_fn,
            init=level_init,
            xs=jnp.arange(1, T + future),
        )
    if future > 0:
        numpyro.deterministic("y_forecast", ys[-future:])

Full reproducible code:

import jax.numpy as jnp
import numpy as np

# NOTE(review): SEED is never defined in the original post, so this
# "reproducible" script failed with NameError. Any fixed integer works.
SEED = 0

np.random.seed(SEED)
T = 800  # length of the window used for fitting (y[..., :T])
h = 200  # extra steps simulated past T and held out of the fit
N = 100  # number of series

# Population-level "true" parameters used to simulate the data.
mu_alpha = -0.15
log_mu_sigma = np.log(5)
mu_beta = 0.5

# Per-series parameters, one value per series. The gamma draw has mean
# shape * scale = exp(log_mu_sigma) = 5.
alpha = np.random.normal(mu_alpha, 0.35, size=N)
sigma = np.random.gamma(100, np.exp(log_mu_sigma) / 100, size=N)
beta = np.random.normal(mu_beta, 0.25, size=N)

# The increment at time index t has mean alpha + beta * t / 365.25 and sd
# sigma; the cumulative sum along the last (time) axis turns the increments
# into random walks.
trend = beta[:, None] * np.arange(T + h) / 365.25
rw_with_drift = np.random.normal(alpha[:, None] + trend, sigma[:, None]).cumsum(-1)
y = jnp.array(rw_with_drift)  # (N, T + h): series along axis 0, time along axis 1

def hierarchical_rw_model(y=None, future=0):
    """Hierarchical Gaussian random walk with per-series drift and trend.

    Each series ``i`` evolves as::

        y[i, t] ~ Normal(y[i, t-1] + alpha[i] + beta[i] * t / 365.25, sigma[i])

    where ``alpha``, ``beta`` and ``sigma`` are per-series draws from shared
    (global) priors.

    Args:
        y: Observations laid out as ``(N, T)``: series along axis 0, time along
            axis 1. ``y.shape[0]`` sizes the ``time_series`` plate and
            ``y[:, 0]`` seeds the walk. ``None`` skips conditioning; note that
            ``N`` and ``T`` are then 0, so the plates are empty.
        future: Number of extra steps to simulate past the end of ``y``. Those
            steps are not covered by the conditioned data, and are exposed as
            the ``y_forecast`` deterministic site, shape ``(future, N)``
            (time-major, like everything ``scan`` stacks).

    Notes:
        ``scan`` steps over the *leading* axis of its inputs, and the
        conditioned ``y`` site is consumed one leading-axis slice per step.
        The observations are therefore handed over time-major
        (``y[:, 1:].T``, shape ``(T - 1, N)``). The series axis is declared
        with a plate *inside* the scanned function, which is how a plate and
        ``scan`` are combined.
    """
    N = 0 if y is None else y.shape[0]
    T = 0 if y is None else y.shape[1]
    level_init = 0 if y is None else y[:, 0]

    # Global (population-level) priors.
    mu_alpha = numpyro.sample("mu_alpha", dist.Normal(0, 1))
    sig_alpha = numpyro.sample("sig_alpha", dist.Exponential(2.5))
    mu_beta = numpyro.sample("mu_beta", dist.Normal(0, 1))
    sig_beta = numpyro.sample("sig_beta", dist.Exponential(1))
    mu_sigma = numpyro.sample("mu_sigma", dist.Normal(0, 1))

    # Per-series parameters: one draw per series.
    with numpyro.plate("time_series", N):
        alpha = numpyro.sample("alpha", dist.Normal(mu_alpha, sig_alpha))
        beta = numpyro.sample("beta", dist.Normal(mu_beta, sig_beta))
        sigma = numpyro.sample("sigma", dist.HalfNormal(jnp.exp(mu_sigma)))

    def transition_fn(y_prev, t):
        # Advance all N series by one step; `t` is the scalar time index from xs.
        exp_val = y_prev + alpha + beta * t / 365.25
        # The time axis belongs to scan; the independent series axis is
        # declared here, inside the scanned body, as a plate.
        with numpyro.plate("time_series", N, dim=-1):
            y_t = numpyro.sample("y", dist.Normal(exp_val, sigma))
        # The new level is both the next carry and the per-step output.
        return y_t, y_t

    # Time must be the leading axis of the conditioned data, so the (N, T)
    # observations become (T - 1, N). Without data there is nothing to
    # condition on (and y[:, 1:] would fail on None).
    data = {} if y is None else {"y": y[:, 1:].T}
    with numpyro.handlers.condition(data=data):
        _, ys = scan(
            f=transition_fn,
            init=level_init,
            xs=jnp.arange(1, T + future),
        )
    if future > 0:
        numpyro.deterministic("y_forecast", ys[-future:])

# Imports for the inference step. The original post never imported these.
# The model function is traced only when `mcmc.run` is called, and it looks
# up `numpyro`, `dist` and `scan` as globals at that point, so they must be
# in scope before `run`.
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.control_flow import scan
from numpyro.infer import MCMC, NUTS

# Fit on the first T time steps only; the last h simulated steps are held out.
kernel = NUTS(hierarchical_rw_model)
mcmc = MCMC(kernel, num_warmup=500, num_samples=1000, num_chains=2)
mcmc.run(random.PRNGKey(0), y=y[..., :T])
samples = mcmc.get_samples()