Scan along a plate?

Is there a way to use `scan` + `condition` along a specific plate? I have a hierarchical random-walk model, and it seems I have to transpose the array so that time is along the first dimension for `condition` + `scan` to work.

Is there any way around this? I’d love to have a pattern like the following (I know you can’t have `scan` inside a plate, but I think it illustrates what I’m looking for):

with time_plate:
     _, ys = scan(**kwargs)

Here’s the model:

def hierarchical_rw_model(y=None, future=0):
    """Hierarchical random walk with per-series drift and time trend.

    For each series ``n`` and step ``t >= 1``::

        y[n, t] ~ Normal(y[n, t - 1] + alpha[n] + beta[n] * t / 365.25, sigma[n])

    where ``alpha``, ``beta`` and ``sigma`` are drawn per series from shared
    population-level priors.

    Args:
        y: Observed series of shape ``(N, T)`` -- series on the first axis,
            time on the last. ``y[:, 0]`` is the fixed initial level and
            ``y[:, 1:]`` conditions the ``"y"`` sample site.
        future: Number of extra steps to simulate after the last observation.
            When positive, those steps are recorded in the deterministic site
            ``"y_forecast"`` with shape ``(future, N)`` (time first, as
            stacked by ``scan``).

    Raises:
        ValueError: If ``y`` is None. The number of series, the number of
            observed steps and the initial level are all read from ``y``, so
            the model cannot be built without it.
    """
    if y is None:
        raise ValueError(
            "hierarchical_rw_model requires observed y of shape (N, T); "
            "N, T and the initial level are all taken from it"
        )
    N, T = y.shape
    level_init = y[:, 0]

    # Global (population-level) vars
    mu_alpha = numpyro.sample("mu_alpha", dist.Normal(0, 1))
    sig_alpha = numpyro.sample("sig_alpha", dist.Exponential(2.5))
    mu_beta = numpyro.sample("mu_beta", dist.Normal(0, 1))
    sig_beta = numpyro.sample("sig_beta", dist.Exponential(1))
    mu_sigma = numpyro.sample("mu_sigma", dist.Normal(0, 1))

    # Per-series vars, each of shape (N,)
    with numpyro.plate("time_series", N):
        alpha = numpyro.sample("alpha", dist.Normal(mu_alpha, sig_alpha))
        beta = numpyro.sample("beta", dist.Normal(mu_beta, sig_beta))
        sigma = numpyro.sample("sigma", dist.HalfNormal(jnp.exp(mu_sigma)))

    def transition_fn(y_prev, t):
        """Advance all N series by one step; ``y_prev`` has shape (N,)."""
        exp_val = y_prev + alpha + beta * t / 365.25
        # Combine a plate with scan by declaring the plate *inside* the step
        # function: scan owns the (leading) time axis, the plate owns the
        # series axis of each step's observation.
        with numpyro.plate("time_series", N):
            y_t = numpyro.sample("y", dist.Normal(exp_val, sigma))
        # The new level is both the next carry and this step's output.
        return y_t, y_t

    # scan hands conditioned values to a site one leading-axis slice per step,
    # so observations must be laid out time-first: (T - 1, N). With
    # future > 0 the scan runs past the observed rows and the extra steps are
    # sampled instead of conditioned.
    with numpyro.handlers.condition(data={"y": y[:, 1:].T}):
        _, ys = scan(
            f=transition_fn,
            init=level_init,
            xs=jnp.arange(1, T + future),
        )
    if future > 0:
        numpyro.deterministic("y_forecast", ys[-future:])

Full reproducible code:

# Simulate N random walks with per-series drift (alpha), trend slope (beta)
# and noise scale (sigma). T steps are used for fitting; h more are simulated.
# NOTE(review): SEED is not defined in this snippet -- it must be set before
# this runs for the script to be reproducible.
np.random.seed(SEED)
T, h, N = 800, 200, 100

# Population-level parameters from which the per-series ones are drawn.
mu_alpha = -0.15
log_mu_sigma = np.log(5)
mu_beta = 0.5

# Per-series parameters. Keep the draw order: it fixes which random numbers
# each array receives under the seed.
alpha = np.random.normal(mu_alpha, 0.35, size=N)
# Gamma(shape=100, scale=exp(log_mu_sigma)/100) has mean exp(log_mu_sigma).
sigma = np.random.gamma(100, np.exp(log_mu_sigma) / 100, size=N)
beta = np.random.normal(mu_beta, 0.25, size=N)

# trend[n, t] = beta[n] * t / 365.25, shape (N, T + h)
trend = np.outer(beta, np.arange(T + h)) / 365.25
# Increments ~ Normal(alpha[n] + trend[n, t], sigma[n]); a cumulative sum over
# the time axis turns them into random walks.
rw_with_drift = np.cumsum(
    np.random.normal(alpha[:, np.newaxis] + trend, sigma[:, np.newaxis]),
    axis=-1,
)
y = jnp.array(rw_with_drift)



def hierarchical_rw_model(y=None, future=0):
    """Hierarchical random walk with per-series drift and time trend.

    For each series ``n`` and step ``t >= 1``::

        y[n, t] ~ Normal(y[n, t - 1] + alpha[n] + beta[n] * t / 365.25, sigma[n])

    where ``alpha``, ``beta`` and ``sigma`` are drawn per series from shared
    population-level priors.

    Args:
        y: Observed series of shape ``(N, T)`` -- series on the first axis,
            time on the last. ``y[:, 0]`` is the fixed initial level and
            ``y[:, 1:]`` conditions the ``"y"`` sample site.
        future: Number of extra steps to simulate after the last observation.
            When positive, those steps are recorded in the deterministic site
            ``"y_forecast"`` with shape ``(future, N)`` (time first, as
            stacked by ``scan``).

    Raises:
        ValueError: If ``y`` is None. The number of series, the number of
            observed steps and the initial level are all read from ``y``, so
            the model cannot be built without it.
    """
    if y is None:
        raise ValueError(
            "hierarchical_rw_model requires observed y of shape (N, T); "
            "N, T and the initial level are all taken from it"
        )
    N, T = y.shape
    level_init = y[:, 0]

    # Global (population-level) vars
    mu_alpha = numpyro.sample("mu_alpha", dist.Normal(0, 1))
    sig_alpha = numpyro.sample("sig_alpha", dist.Exponential(2.5))
    mu_beta = numpyro.sample("mu_beta", dist.Normal(0, 1))
    sig_beta = numpyro.sample("sig_beta", dist.Exponential(1))
    mu_sigma = numpyro.sample("mu_sigma", dist.Normal(0, 1))

    # Per-series vars, each of shape (N,)
    with numpyro.plate("time_series", N):
        alpha = numpyro.sample("alpha", dist.Normal(mu_alpha, sig_alpha))
        beta = numpyro.sample("beta", dist.Normal(mu_beta, sig_beta))
        sigma = numpyro.sample("sigma", dist.HalfNormal(jnp.exp(mu_sigma)))

    def transition_fn(y_prev, t):
        """Advance all N series by one step; ``y_prev`` has shape (N,)."""
        exp_val = y_prev + alpha + beta * t / 365.25
        # Combine a plate with scan by declaring the plate *inside* the step
        # function: scan owns the (leading) time axis, the plate owns the
        # series axis of each step's observation.
        with numpyro.plate("time_series", N):
            y_t = numpyro.sample("y", dist.Normal(exp_val, sigma))
        # The new level is both the next carry and this step's output.
        return y_t, y_t

    # scan hands conditioned values to a site one leading-axis slice per step,
    # so observations must be laid out time-first: (T - 1, N). With
    # future > 0 the scan runs past the observed rows and the extra steps are
    # sampled instead of conditioned.
    with numpyro.handlers.condition(data={"y": y[:, 1:].T}):
        _, ys = scan(
            f=transition_fn,
            init=level_init,
            xs=jnp.arange(1, T + future),
        )
    if future > 0:
        numpyro.deterministic("y_forecast", ys[-future:])
    
# Fit on the first T steps only; the remaining h simulated steps are not
# passed to the model.
kernel = NUTS(hierarchical_rw_model)
mcmc = MCMC(
    kernel,
    num_warmup=500,
    num_samples=1000,
    num_chains=2,
)
mcmc.run(random.PRNGKey(0), y=y[..., :T])
samples = mcmc.get_samples()