Hierarchical Time Series Modeling for Data with Different Lengths

Hello,

This code was slightly modified from this notebook, specifically under the Multivariate Time Series Forecasting section. It is modeling random walks with drift and trend. The coefficients of the timeseries have a hierarchical structure. Before we get started, this is what I am importing:

import numpy as np
import os
import pandas as pd
import matplotlib.pyplot as plt
idx = pd.IndexSlice
# import jax
from jax import random
import jax.numpy as jnp
import arviz as az
import numpyro as ny
from numpyro.contrib.control_flow import scan
from numpyro.diagnostics import hpdi
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive
ny.set_host_device_count(4)

I am trying to understand how to do hierarchical timeseries analysis on data where each of the timeseries have different lengths. I am able to fit the model when all the timeseries have the same length. That code is here:

# Creating the data
np.random.seed(42)
T = 800
h = 200
N = 10

mu_alpha = -0.15
log_mu_sigma = np.log(5)
mu_beta = 0.5

alpha = np.random.normal(mu_alpha, 0.35, size=N)
sigma = np.random.gamma(100, np.exp(log_mu_sigma)/100, size=N)
beta = np.random.normal(mu_beta, 0.25, size=N)

trend = beta[:,None]*np.arange(T+h)/365.25
rw_with_drift = np.random.normal(alpha[:,None]+trend, sigma[:,None]).cumsum(-1) 

y = jnp.array( rw_with_drift  )

fig, ax = plt.subplots()
ax.plot(np.arange(T+h), y.T, alpha=0.25)
ax.set(xlabel='Time', ylabel='Y',title='RW with Drift and Trend')
ax.axvline(T, color='k', ls='--', label='Training Period Ends')
ax.legend()
plt.show()

# Defining the model
def hierarchical_rw_model(y=None, future=0):
	N = 0 if y is None else y.shape[0]
	T = 0 if y is None else y.shape[1]
	level_init = 0 if y is None else y[:,0]
	
	# Global Vars
	mu_alpha = ny.sample("mu_alpha", dist.Normal(0,1))
	sig_alpha = ny.sample("sig_alpha", dist.Exponential(2.5))

	mu_beta = ny.sample("mu_beta", dist.Normal(0,1))
	sig_beta = ny.sample("sig_beta", dist.Exponential(1))

	mu_sigma = ny.sample("mu_sigma", dist.Normal(0,1))
	test = y[:,1:].T
	
	# Time Series Level vars
	with ny.plate("time_series", N):
		alpha = ny.sample("alpha", dist.Normal(mu_alpha,sig_alpha))
		beta = ny.sample("beta", dist.Normal(mu_beta, sig_beta))
		sigma = ny.sample("sigma", dist.HalfNormal(jnp.exp(mu_sigma)))
	
	def transition_fn(y_t_minus_1, t):
		exp_val = y_t_minus_1 + alpha + beta*t/365.25
		# Observational model
		y_ = ny.sample("y", dist.Normal(exp_val, sigma))
		
		# Recursive update
		y_t_minus_1 = y_
		return y_t_minus_1, y_
	
	# It seems like condition+scan may only work along the first axis,
	# so transposing y
	with ny.handlers.condition(data={"y": test}):
		_, ys = scan(
				f=transition_fn, 
				init=level_init, 
				xs=jnp.arange(1, T + future)
		)
	if future > 0:
		ny.deterministic("y_forecast", ys[-future:])

# Fitting the model
kernel = NUTS(hierarchical_rw_model)
mcmc = MCMC(kernel, num_warmup=500, num_samples=1000, num_chains=2)
mcmc.run(random.PRNGKey(0), y=y[:,:T])

However, when I try to use masking to accomodate data that is not all the same length, things go poorly:

# Generating the data
np.random.seed(42)
T = 800
h = 200
N = 10

mu_alpha = -0.15
log_mu_sigma = np.log(5)
mu_beta = 0.5

alpha = np.random.normal(mu_alpha, 0.35, size=N)
sigma = np.random.gamma(100, np.exp(log_mu_sigma)/100, size=N)
beta = np.random.normal(mu_beta, 0.25, size=N)

trend = beta[:,None]*np.arange(T+h)/365.25
rw_with_drift = np.random.normal(alpha[:,None]+trend, sigma[:,None]).cumsum(-1) 

### MASKING THE DATA #####
eliminate_vals = np.random.choice(np.arange(101), N)
for i in range(N):
	rw_with_drift[i,:eliminate_vals[i]] = np.nan

mask = jnp.array(~np.isnan(rw_with_drift))
##########################

y = jnp.array( rw_with_drift  )

fig, ax = plt.subplots()
ax.plot(np.arange(T+h), y.T, alpha=0.25)
ax.set(xlabel='Time', ylabel='Y',title='RW with Drift and Trend')
ax.axvline(T, color='k', ls='--', label='Training Period Ends')
ax.legend()
plt.show()

# Defining the model
def hierarchical_rw_model(y=None, mask = True, future=0):
	N = 0 if y is None else y.shape[0]
	T = 0 if y is None else y.shape[1]
	level_init = 0 if y is None else y[:,0]
	
	# Global Vars
	mu_alpha = ny.sample("mu_alpha", dist.Normal(0,1))
	sig_alpha = ny.sample("sig_alpha", dist.Exponential(2.5))

	mu_beta = ny.sample("mu_beta", dist.Normal(0,1))
	sig_beta = ny.sample("sig_beta", dist.Exponential(1))

	mu_sigma = ny.sample("mu_sigma", dist.Normal(0,1))
	test = y[:,1:].T
	test_mask = mask[:,1:].T
	
	# Time Series Level vars
	with ny.plate("time_series", N):
		alpha = ny.sample("alpha", dist.Normal(mu_alpha,sig_alpha))
		beta = ny.sample("beta", dist.Normal(mu_beta, sig_beta))
		sigma = ny.sample("sigma", dist.HalfNormal(jnp.exp(mu_sigma)))
	
	def transition_fn(y_t_minus_1, t):
		exp_val = y_t_minus_1 + alpha + beta*t/365.25
		# Observational model
		# current_mask = test_mask[t]
		y_ = ny.sample("y", dist.Normal(exp_val, sigma))
		
		# Recursive update
		y_t_minus_1 = y_
		return y_t_minus_1, y_
	
	# It seems like condition+scan may only work along the first axis,
	# so transposing y
	with ny.handlers.mask(mask = test_mask): ### MASKING HERE ######
		with ny.handlers.condition(data={"y": test}):
			_, ys = scan(
					f=transition_fn, 
					init=level_init, 
					xs=jnp.arange(1, T + future)
			)
	if future > 0:
	        ny.deterministic("y_forecast", ys[-future:])

 # Fitting the model
kernel = NUTS(hierarchical_rw_model)
mcmc = MCMC(kernel, num_warmup=500, num_samples=1000, num_chains=2)
mcmc.run(random.PRNGKey(0), y=y[:,:T], mask = mask[:,:T])

Use the following code to show the posterior distribution:

# Plotting
samples = mcmc.get_samples()
coords = {"time_series":np.arange(N)}
dims = {"alpha":["time_series"], 
        "beta":["time_series"],
        "sigma":['time_series']}

idata = az.from_numpyro(
    mcmc, 
    coords=coords,
    dims=dims
)

az.plot_trace(idata)
plt.tight_layout()

Thank you for your help!

Could you try to move the mask to the transition_fn? scan works for condition but might not work properly for mask handler.