Scan along a plate?

Is there a way to use scan + condition along a specific plate? I have a hierarchical random walk model, and it seems like I have to transpose the array so that time is along the first dimension for condition + scan to work.

Are there any ways around this? I’d love to have a pattern that looks like this (I know you can’t have scan inside of a plate, but I think it illustrates what I’m looking for):

with time_plate:
    _, ys = scan(**kwargs)

Here’s the model:

def hierarchical_rw_model(y=None, future=0):
    N = 0 if y is None else y.shape[0]
    T = 0 if y is None else y.shape[1]
    level_init = 0 if y is None else y[:,0]
    
    # Global Vars
    mu_alpha = numpyro.sample("mu_alpha", dist.Normal(0,1))
    sig_alpha = numpyro.sample("sig_alpha", dist.Exponential(2.5))
    mu_beta = numpyro.sample("mu_beta", dist.Normal(0,1))
    sig_beta = numpyro.sample("sig_beta", dist.Exponential(1))
    mu_sigma = numpyro.sample("mu_sigma", dist.Normal(0,1))
    
    # Time Series Level vars
    with numpyro.plate("time_series", N):
        alpha = numpyro.sample("alpha", dist.Normal(mu_alpha,sig_alpha))
        beta = numpyro.sample("beta", dist.Normal(mu_beta, sig_beta))
        sigma = numpyro.sample("sigma", dist.HalfNormal(jnp.exp(mu_sigma)))
    
    def transition_fn(y_t_minus_1, t):
        exp_val = y_t_minus_1 + alpha + beta*t/365.25
        # Observational model
        y_ = numpyro.sample("y", dist.Normal(exp_val, sigma))
        
        # Recursive update
        y_t_minus_1 = y_
        return y_t_minus_1, y_
    
    # numpyro's scan (like jax.lax.scan) iterates over the leading axis,
    # so y is transposed to time-first shape (T - 1, N) before conditioning
    with numpyro.handlers.condition(data={"y": y[:,1:].T}):
        _, ys = scan(
            f=transition_fn, 
            init=level_init, 
            xs=jnp.arange(1, T + future)
        )
    if future > 0:
        numpyro.deterministic("y_forecast", ys[-future:])

Full reproducible code:

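import numpy as np
import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.control_flow import scan
from numpyro.infer import MCMC, NUTS

SEED = 0  # hypothetical value: the original post does not show what SEED was
np.random.seed(SEED)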
T = 800
h = 200
N = 100


mu_alpha = -0.15
log_mu_sigma = np.log(5)
mu_beta = 0.5


alpha = np.random.normal(mu_alpha, 0.35, size=N)
sigma = np.random.gamma(100, np.exp(log_mu_sigma)/100, size=N)
beta = np.random.normal(mu_beta, 0.25, size=N)

trend = beta[:,None]*np.arange(T+h)/365.25
rw_with_drift = np.random.normal(alpha[:,None]+trend, sigma[:,None]).cumsum(-1) 
y = jnp.array(rw_with_drift)



def hierarchical_rw_model(y=None, future=0):
    N = 0 if y is None else y.shape[0]
    T = 0 if y is None else y.shape[1]
    level_init = 0 if y is None else y[:,0]
    
    # Global Vars
    mu_alpha = numpyro.sample("mu_alpha", dist.Normal(0,1))
    sig_alpha = numpyro.sample("sig_alpha", dist.Exponential(2.5))
    mu_beta = numpyro.sample("mu_beta", dist.Normal(0,1))
    sig_beta = numpyro.sample("sig_beta", dist.Exponential(1))
    mu_sigma = numpyro.sample("mu_sigma", dist.Normal(0,1))
    
    # Time Series Level vars
    with numpyro.plate("time_series", N):
        alpha = numpyro.sample("alpha", dist.Normal(mu_alpha,sig_alpha))
        beta = numpyro.sample("beta", dist.Normal(mu_beta, sig_beta))
        sigma = numpyro.sample("sigma", dist.HalfNormal(jnp.exp(mu_sigma)))
    
    def transition_fn(y_t_minus_1, t):
        exp_val = y_t_minus_1 + alpha + beta*t/365.25
        # Observational model
        y_ = numpyro.sample("y", dist.Normal(exp_val, sigma))
        
        # Recursive update
        y_t_minus_1 = y_
        return y_t_minus_1, y_
    
    # numpyro's scan (like jax.lax.scan) iterates over the leading axis,
    # so y is transposed to time-first shape (T - 1, N) before conditioning
    with numpyro.handlers.condition(data={"y": y[:,1:].T}):
        _, ys = scan(
            f=transition_fn, 
            init=level_init, 
            xs=jnp.arange(1, T + future)
        )
    if future > 0:
        numpyro.deterministic("y_forecast", ys[-future:])
    
kernel = NUTS(hierarchical_rw_model)
mcmc = MCMC(kernel, num_warmup=500, num_samples=1000, num_chains=2)
mcmc.run(random.PRNGKey(0), y=y[...,:T])
samples = mcmc.get_samples()

Do you know how to do this if each of your random walks has a different number of data points?

I think you could do something like this with numpyro.handlers.mask, padding the shorter series to a common length (see the thread “Time varying dimension measurement in gaussian hmm”).

If you search for handlers.mask in the forum search, there are some similar examples that might help.


And to answer the original question: this was a misunderstanding of plates on the part of my younger self. What I really wanted to do was scan and condition along a specific axis, but numpyro’s scan is constrained to work over the leading axis.
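NumPyro's scan follows `jax.lax.scan` here: it loops over axis 0 of `xs`, and conditioned data is sliced along axis 0 too. The usual workaround is therefore to move the time axis to the front (`y[:, 1:].T` in the model above, or `jnp.moveaxis` for arrays with more dimensions) and move results back afterwards. A minimal pure-JAX sketch of that pattern, using a hypothetical helper `scan_cumsum_along`:

```python
import jax
import jax.numpy as jnp


def scan_cumsum_along(x, axis):
    """Cumulative sum along `axis`, computed with lax.scan (which always scans axis 0)."""
    # Bring the requested axis to the front so lax.scan iterates over it.
    xs = jnp.moveaxis(x, axis, 0)

    def step(carry, x_t):
        carry = carry + x_t  # a random walk is a cumulative sum of increments
        return carry, carry

    init = jnp.zeros(xs.shape[1:], dtype=xs.dtype)
    _, out = jax.lax.scan(step, init, xs)
    # Move the scanned axis back to where the caller had it.
    return jnp.moveaxis(out, 0, axis)


# Increments laid out as (N series, T time steps), like y in the thread.
increments = jnp.arange(6.0).reshape(2, 3)
print(scan_cumsum_along(increments, axis=1))  # rows: [0, 1, 3] and [3, 7, 12]
```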

Wait, the original code doesn’t work. It runs but the output is nonsense…

If you mean this example (a random walk), I have a notebook with more examples that work: case_studies/Time Series with Numpyro - Random Walk Intro.ipynb (main branch of kylejcaron/case_studies on GitHub).

If you meant the HMM I linked, I was just referring to the pattern they use with poutine.mask (numpyro’s equivalent is handlers.mask). I can’t speak for the rest of their code, so sorry if that confused you.

Here are better examples I just found by searching for handlers.mask and poutine.mask: the thread “Variable length sequences and masking” and “Example: Hidden Markov Models” in the Pyro Tutorials 1.9.1 documentation.

Oh, I was using az.plot_trace() directly on the samples, which gives strange results. My apologies. Thank you so much for your help! It really made my day better.
