Hi all,
I’m trying to fit several separate time series, stored as a matrix c_{t,j} where t indexes the timestep and j indexes the series. I’m trying to implement a state-space model like the following:
y_{t,j} = \frac{c_{t,j}^{\lambda_j}-1}{\lambda_j} \text{(a Box-Cox transform)}
y_{t,j} \sim \mathcal{N}(\sum_r{Z_{t, r} \beta_{t, j, r}}, \sigma_{\epsilon, j}^2)
where \beta_{t, j, r} are random walks of the form:
\beta_{t, j, r} \sim \mathcal{N}(\beta_{t-1, j, r}, \sigma_{\eta, j, r}^2)
And Z is a matrix like (for 3 regressors):
\begin{pmatrix} 1 & 0 & 0\\ 1 & 1 & 0\\ 1 & 0 & 1\\ 1 & 0 & 0\\ 1 & 1 & 0\\ \vdots & \vdots & \vdots \end{pmatrix}
Or in words:
- Use a Box-Cox transform to change the skewed c_{t,j} into approximately normally distributed data y_{t,j}
- Estimate y_{t,j} using time-varying coefficients \beta_{t, j, r} (r indexes the regressors/coefficients): a global intercept (the first column of Z is always 1) and seasonal terms (this Z encodes 3 seasons; the dummy for the first season is dropped to ensure identifiability, so the intercept carries the trend for the reference season 0)
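To make the two ingredients above concrete, here is a minimal JAX sketch of the design matrix and the transform. The helper names seasonal_design, box_cox and inv_box_cox are hypothetical, and it assumes the season simply cycles with t (as in the Z shown above) and that \lambda is non-zero.

```python
import jax.numpy as jnp


def seasonal_design(n_steps, n_seasons):
    """Intercept column plus dummies for seasons 1..n_seasons-1.

    Season 0 is the reference season, so it gets no dummy column.
    Assumption: the season at step t is t mod n_seasons.
    """
    season = jnp.arange(n_steps) % n_seasons
    dummies = (season[:, None] == jnp.arange(1, n_seasons)[None, :]).astype(jnp.float32)
    return jnp.concatenate([jnp.ones((n_steps, 1)), dummies], axis=-1)


def box_cox(c, lam):
    """Box-Cox transform for positive c and non-zero lam."""
    return (jnp.power(c, lam) - 1.0) / lam


def inv_box_cox(y, lam):
    """Inverse of box_cox, valid where lam * y + 1 > 0."""
    return jnp.power(lam * y + 1.0, 1.0 / lam)


Z = seasonal_design(n_steps=5, n_seasons=3)  # reproduces the first five rows of Z above
```
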
Given that the y_{t,j} are now normal, I want to use a Kalman filter, which exploits normal-normal conjugacy to perform inference efficiently. Here is my current code:
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
# assumption: there are no sample sites inside the loop, so plain lax.scan suffices
from jax.lax import scan


def model(Z, c, P0):
    # Z  (n_steps, n_regressors)
    # c  (n_steps, n_series)
    # P0 (n_series, n_regressors)
    c = jnp.expand_dims(c, -1)  # (n_steps, n_series, 1)
    Z = jnp.expand_dims(Z, -2)  # (n_steps, 1, n_regressors)
    nt = c.shape[-3]
    nj = c.shape[-2]
    nr = Z.shape[-1]

    time_plate = numpyro.plate("time", nt, dim=-3)  # declared but not used below
    series_plate = numpyro.plate("series", nj, dim=-2)
    regressor_plate = numpyro.plate("regressor", nr, dim=-1)

    with series_plate:
        sigma_eps = numpyro.sample("sigma_eps", dist.HalfNormal(1.))  # (n_series, 1)
        lambda_ = numpyro.sample("lambda_", dist.Uniform(0.5, 1.0))  # (n_series, 1)
    with series_plate, regressor_plate:
        sigma_eta = numpyro.sample("sigma_eta", dist.HalfNormal(1.))  # (n_series, n_regressors)

    y = (jnp.power(c, lambda_) - 1) / lambda_  # box-cox transformation

    # this ensures the model is identifiable:
    # the global intercept is centred so that sum(Z[0] * a0) = y[0]
    # the seasonal terms are centred around 0
    a0 = jnp.concatenate([y[0], jnp.zeros((nj, nr - 1))], -1)

    def transition_fn(carry, t):
        at, Pt, log_prob = carry
        yt = y[t]  # (n_series, 1)
        Zt = Z[t]  # (1, n_regressors)
        # prediction error and its variance
        vt = yt - jnp.sum(at * Zt, -1, keepdims=True)  # (n_series, 1)
        Ft = jnp.sum(Pt * jnp.square(Zt), -1, keepdims=True) + jnp.square(sigma_eps)  # (n_series, 1)
        Kt = Pt * Zt / Ft  # (n_series, n_regressors)
        # Gaussian log-density of vt up to the 2*pi constant: log|Ft| + vt^2 / Ft
        log_prob += -0.5 * jnp.sum(jnp.log(jnp.fabs(Ft)) + jnp.square(vt) / Ft)  # scalar
        # element-wise update: only the variances in Pt are tracked
        at = at + Kt * vt  # (n_series, n_regressors)
        Pt = Pt * (1 - Kt * Zt) + jnp.square(sigma_eta)  # (n_series, n_regressors)
        return (at, Pt, log_prob), (at, Pt)

    (_, _, log_prob), (at, Pt) = scan(
        transition_fn, (a0, P0, 0.0),
        jnp.arange(nt),
        length=nt
    )

    numpyro.factor("kalman_filter_lp", -0.5 * nt * nj * jnp.log(2 * jnp.pi) + log_prob)
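As a sanity check on a hand-written term like this, the per-step contribution should equal the Gaussian log-density of the prediction error v_t with variance F_t, that is -0.5 (log 2\pi + log F_t + v_t^2 / F_t). Below is a small sketch comparing that expression with jax.scipy.stats.norm.logpdf; the function name step_log_prob and the numbers are made up for illustration.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm


def step_log_prob(v, F):
    """Gaussian log-density of a prediction error v with variance F (hypothetical helper)."""
    return -0.5 * (jnp.log(2 * jnp.pi) + jnp.log(F) + jnp.square(v) / F)


# illustrative prediction errors and variances, not real data
v = jnp.array([0.3, -1.2, 2.0])
F = jnp.array([0.5, 1.0, 4.0])

manual = step_log_prob(v, F)
reference = norm.logpdf(v, loc=0.0, scale=jnp.sqrt(F))  # scale is a standard deviation
```

If the two arrays disagree, the usual suspects are a misplaced bracket around the log or passing a variance where a standard deviation is expected.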
There are a few things I don’t like:
- Distribution. We’re writing our own custom log probability using numpyro.factor. Is there any way to rewrite the transition_fn as a distribution to avoid this?
- No likelihood. This is linked to the problem above. There is no line like numpyro.sample("obs", dist.X(), obs=y). This also means I can’t do useful things like using Predictive for prior predictive simulation or forecasting. Is it possible to use the condition handler?
- Plates. The above model doesn’t feel very numpyro-nic. I’ve written the transition function as matrices, but it’s all element-wise. If it’s possible, I’d much rather write a series-specific transition function by removing the -2 series dimension, and do something like:
with series_plate:
    (_, _, log_prob), (at, Pt) = scan(
        transition_fn, (a0, P0, 0.0),
        jnp.arange(nt),
        length=nt
    )
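On the plates point, one possible direction is to write the recursion for a single series and let jax.vmap add the series dimension. The following is only a sketch under some assumptions: kalman_loglik and batched_loglik are hypothetical names, it tracks a full (n_regressors, n_regressors) covariance per series rather than the element-wise variances in the model code above, and it uses jax.lax.scan and jax.scipy.stats.norm rather than a NumPyro distribution.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def kalman_loglik(y, Z, a0, P0, sigma_eps, sigma_eta):
    """Log-likelihood of ONE series with random-walk regression coefficients.

    y: (n_steps,), Z: (n_steps, n_regressors), a0: (n_regressors,),
    P0: (n_regressors, n_regressors), sigma_eps: scalar, sigma_eta: (n_regressors,)
    """
    Q = jnp.diag(jnp.square(sigma_eta))  # random-walk innovation covariance

    def step(carry, inputs):
        a, P = carry
        y_t, z_t = inputs
        v = y_t - z_t @ a                          # one-step prediction error
        F = z_t @ P @ z_t + jnp.square(sigma_eps)  # prediction-error variance
        K = P @ z_t / F                            # Kalman gain, (n_regressors,)
        lp = norm.logpdf(v, loc=0.0, scale=jnp.sqrt(F))
        a = a + K * v                              # filtered mean; unchanged by the random-walk step
        P = P - jnp.outer(K, z_t @ P) + Q          # filtered covariance plus innovation
        return (a, P), lp

    _, lps = jax.lax.scan(step, (a0, P0), (y, Z))
    return jnp.sum(lps)


# vmap over series: y is (n_steps, n_series), Z is shared, the rest lead with n_series
batched_loglik = jax.vmap(kalman_loglik, in_axes=(1, None, 0, 0, 0, 0))

# toy usage: 6 steps, 2 series, intercept + 2 seasonal dummies (made-up numbers)
nt, nj, nr = 6, 2, 3
season = jnp.arange(nt) % 3
Z = (season[:, None] == jnp.arange(3)[None, :]).astype(jnp.float32)
Z = Z.at[:, 0].set(1.0)  # first column becomes the intercept
y = jax.random.normal(jax.random.PRNGKey(0), (nt, nj))
a0 = jnp.zeros((nj, nr))
P0 = jnp.tile(jnp.eye(nr), (nj, 1, 1))
sigma_eps = jnp.array([0.5, 1.0])
sigma_eta = jnp.full((nj, nr), 0.3)

ll = batched_loglik(y, Z, a0, P0, sigma_eps, sigma_eta)  # shape (n_series,)
```

The per-series log-likelihoods could then be summed and passed to numpyro.factor as before; this does not by itself answer the question of wrapping the filter in a distribution.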
Is there any way I could rewrite this model to follow numpyro best practices?
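For context on the prior predictive point, this is roughly the forward simulation I would like Predictive to give me, written directly in JAX for a single series. It is a sketch with assumptions: simulate is a hypothetical helper, \beta_0 is fixed at a0 rather than drawn from a prior, and the inverse Box-Cox clips lam * y + 1 at zero so the power stays defined.

```python
import jax
import jax.numpy as jnp


def simulate(key, Z, a0, sigma_eps, sigma_eta, lam):
    """Draw one series (y, c) from the generative model (hypothetical helper).

    Z: (n_steps, n_regressors), a0: (n_regressors,),
    sigma_eps: scalar, sigma_eta: (n_regressors,), lam: non-zero scalar
    """
    k_eta, k_eps = jax.random.split(key)
    n_steps, n_reg = Z.shape
    # random-walk coefficients: beta_t = a0 + cumulative sum of innovations
    eta = sigma_eta * jax.random.normal(k_eta, (n_steps, n_reg))
    eta = eta.at[0].set(0.0)  # assumption: beta_0 is fixed at a0
    beta = a0 + jnp.cumsum(eta, axis=0)
    # observation equation on the transformed scale
    y = jnp.sum(Z * beta, axis=-1) + sigma_eps * jax.random.normal(k_eps, (n_steps,))
    # invert the Box-Cox transform; clipping is an assumption to keep the power defined
    c = jnp.power(jnp.maximum(lam * y + 1.0, 0.0), 1.0 / lam)
    return y, c


Z = jnp.array([[1., 0., 0.], [1., 1., 0.], [1., 0., 1.], [1., 0., 0.]])
a0 = jnp.array([5., 1., -1.])
y_sim, c_sim = simulate(jax.random.PRNGKey(0), Z, a0, 0.2, jnp.full(3, 0.1), 0.5)
```
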
Also, this Kalman filter setup is very similar to the walker package. If we’re able to get a generalised numpyro version running and we think it’s useful for others, I’d be happy to contribute it.
Cheers,
Theo