Multiple Kalman filters in NumPyro

Hi all,

I’m working with data that I want to fit as J separate time series. The data form a matrix c_{t,j}, where t indexes the timestep and j indexes the series. I’m trying to implement a state-space model like the following:

y_{t,j} = \frac{c_{t,j}^{\lambda_j}-1}{\lambda_j} \text{(a Box-Cox transform)}
y_{t,j} \sim \mathcal{N}(\sum_r{Z_{t, r} \beta_{t, j, r}}, \sigma_{\epsilon}^2)

where \beta_{t, j, r} are random walks of the form:

\beta_{t, j, r} \sim \mathcal{N}(\beta_{t-1, j, r}, \sigma_{\eta}^2)

And Z is a matrix like (for r=3):

\begin{pmatrix} 1 & 0 & 0\\ 1 & 1 & 0\\ 1 & 0 & 1\\ 1 & 0 & 0\\ 1 & 1 & 0\\ ... \end{pmatrix}

Or in words:

  • Use a Box-Cox transform to turn the skewed c_{t,j} into approximately normally distributed data y_{t,j}
  • Estimate y_{t,j} using time-varying coefficients \beta_{t, j, r} (r indexes the regressors/coefficients): a global intercept (the first column of Z is always 1) and seasonal terms (this Z encodes 3 seasons; the reference season 0 gets no column of its own, i.e. its effect is fixed at 0 to ensure identifiability, so the overall trend is the trend in the reference season 0)
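
As a small illustration of the two steps above, here is a sketch in plain Python. The helper names box_cox and seasonal_design are made up for this example, and it assumes 3 seasons that repeat in order, with season 0 as the reference.

```python
def box_cox(c, lam):
    """Box-Cox transform of a positive value c; lam must be non-zero here."""
    return (c ** lam - 1.0) / lam


def seasonal_design(n_steps, n_seasons=3):
    """Rows of Z: an intercept column, then dummies for seasons 1..n_seasons-1."""
    rows = []
    for t in range(n_steps):
        season = t % n_seasons
        # season 0 is the reference level, so it gets no dummy column
        dummies = [1 if season == s else 0 for s in range(1, n_seasons)]
        rows.append([1] + dummies)
    return rows


Z = seasonal_design(5)  # reproduces the first five rows of the matrix above
y = [box_cox(c, 0.5) for c in [1.0, 4.0, 9.0]]
```

With lambda = 0.5 the transform is 2 * (sqrt(c) - 1), which makes the values easy to check by hand.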

Given that the y_{t,j} are now normal, I want to use Kalman filters, which exploit normal-normal conjugacy to perform inference efficiently. Here is my current code:

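import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
# assumption: transition_fn has no sample sites, so plain JAX scan is enough
from jax.lax import scan

def model(Z, c, P0):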
    # Z (n_steps, n_regressors)
    # c (n_steps, n_series)
    # P0 (n_series, n_regressors)
    c = jnp.expand_dims(c, -1) # (n_steps, n_series, 1)
    Z = jnp.expand_dims(Z, -2) # (n_steps, 1, n_regressors)
    nt = c.shape[-3]
    nj = c.shape[-2]
    nr = Z.shape[-1]

    time_plate = numpyro.plate("time", nt, dim=-3)
    series_plate = numpyro.plate("series", nj, dim=-2)
    regressor_plate = numpyro.plate("regressor", nr, dim=-1)

    with series_plate:
        sigma_eps = numpyro.sample("sigma_eps", dist.HalfNormal(1.))
        lambda_ = numpyro.sample("lambda_", dist.Uniform(0.5, 1.0))
    with series_plate, regressor_plate:
        sigma_eta = numpyro.sample("sigma_eta", dist.HalfNormal(1.))

    y = (jnp.power(c, lambda_) - 1) / lambda_ # box-cox transformation

    # this ensures the model is identifiable
    # the global intercept is centred so that on Z[0] * a0 = y[0]
    # the seasonal terms are centred around 0
    a0 = jnp.concatenate([y[0], jnp.zeros((nj, nr - 1))], -1)

    def transition_fn(carry, t):
        at, Pt, log_prob = carry
        yt = y[t] # (n_series, 1)
        Zt = Z[t] # (1, n_regressors)
        vt = yt - jnp.sum(at * Zt, -1, keepdims=True) # (n_series, 1)
        Ft = jnp.sum(Pt * jnp.square(Zt), -1, keepdims=True) + jnp.square(sigma_eps)  # (n_series, 1)
        Kt = Pt * Zt / Ft # (n_series, n_regressors)
        log_prob += -0.5 * jnp.sum(jnp.log(jnp.fabs(Ft)) + jnp.square(vt) / Ft)  # scalar, summed over series
        at = at + Kt * vt # (n_series, n_regressors)
        Pt = Pt * (1 - Kt * Zt) + jnp.square(sigma_eta) # (n_series, n_regressors)
        return (at, Pt, log_prob), (at, Pt)

    (_, _, log_prob), (at, Pt) = scan(
        transition_fn, (a0, P0, 0.0), jnp.arange(nt)
    )
    numpyro.factor("kalman_filter_lp", -0.5 * nt * nj * jnp.log(2 * jnp.pi) + log_prob)
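
For reference, the recursions that transition_fn is meant to implement for a single series can be written out as below. This is a reading of the code rather than a textbook derivation: P is stored per coefficient, so only the diagonal of the state covariance is tracked, and the last line is the usual Gaussian log-likelihood of the prediction errors.

```latex
\begin{aligned}
v_t &= y_t - \sum_r Z_{t,r}\, a_{t,r} \\
F_t &= \sum_r P_{t,r}\, Z_{t,r}^2 + \sigma_{\epsilon}^2 \\
K_{t,r} &= P_{t,r}\, Z_{t,r} / F_t \\
a_{t+1,r} &= a_{t,r} + K_{t,r}\, v_t \\
P_{t+1,r} &= P_{t,r}\,(1 - K_{t,r}\, Z_{t,r}) + \sigma_{\eta}^2 \\
\log p(y) &= -\tfrac{1}{2} \sum_t \left( \log 2\pi + \log F_t + v_t^2 / F_t \right)
\end{aligned}
```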

There are a few things I don’t like:

  1. Distribution. We’re writing our own custom log probability using numpyro.factor. Is there any way to rewrite the transition_fn as a distribution to avoid this?
  2. No likelihood. This is linked to the problem above. There is no numpyro.sample("obs", dist.X(), obs=y) statement, which also means I can’t do useful things like using Predictive for prior predictive simulation or forecasting. Is it possible to use the condition handler?
  3. Plates. The above model doesn’t feel very numpyro-nic. I’ve written the transition function in terms of matrices, but it’s all element-wise. If possible, I’d much rather write a series-specific transition function by removing the -2 series dimension, and do something like:
with series_plate:
    (_, _, log_prob), (at, Pt) = scan(
        transition_fn, (a0, P0, 0.0), jnp.arange(nt)
    )

Is there any way I could rewrite this model to follow numpyro best practices?

Also, this Kalman filter setup is very similar to the walker package. If we’re able to get a generalised numpyro version running and we think it’s useful for others, I’d be happy to contribute it.



You can define your own distribution with the same implementation:

class CustomDist(dist.Distribution):
    def log_prob(self, y):
        # move the scan code to here
        raise NotImplementedError

    def sample(self, key, sample_shape=()):
        # implement the simulation
        raise NotImplementedError

But it might be unnecessary. You can use Predictive with numpyro.deterministic rather than numpyro.sample.

For batching, it is better to use jax.vmap. I typically use numpyro.util.soft_vmap for an arbitrary number of batch dimensions.
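
A minimal sketch of the vmap idea, assuming a single regressor so that every quantity in the filter is a scalar. kalman_logp is a hypothetical helper (not a NumPyro function) that scores one series; jax.vmap then maps it over the series axis of y and over per-series scales. jax.lax.scan is used here, which should be fine because nothing inside the loop is a sample site.

```python
import jax
import jax.numpy as jnp


def kalman_logp(y, Z, scale_noise, scale_disturbance, a0, P0):
    """Log-likelihood of one series y (n_steps,) with one regressor Z (n_steps,)."""

    def step(carry, inputs):
        a, P, logp = carry
        yt, Zt = inputs
        v = yt - Zt * a                    # one-step-ahead prediction error
        F = P * Zt**2 + scale_noise**2     # variance of the prediction error
        K = P * Zt / F                     # Kalman gain
        logp = logp - 0.5 * (jnp.log(2 * jnp.pi) + jnp.log(F) + v**2 / F)
        a = a + K * v
        P = P * (1 - K * Zt) + scale_disturbance**2
        return (a, P, logp), None

    zero = jnp.zeros_like(y[0])            # keeps carry dtypes consistent
    (_, _, logp), _ = jax.lax.scan(step, (a0 + zero, P0 + zero, zero), (y, Z))
    return logp


# map over series: y is (n_steps, n_series), scales are (n_series,), Z is shared
batched_logp = jax.vmap(kalman_logp, in_axes=(1, None, 0, 0, None, None))

n_steps, n_series = 6, 3
y = jnp.arange(18.0).reshape(n_steps, n_series) / 10.0
Z = jnp.ones(n_steps)
logps = batched_logp(y, Z, jnp.ones(n_series), 0.5 * jnp.ones(n_series), 0.0, 1.0)
```

The result has one log-likelihood per series, so it could be summed and passed to numpyro.factor, or the same function could back a custom distribution's log_prob.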

Hey, I’m struggling a bit with this one. I created a custom distribution, which is the Kalman filter in 1D, i.e., for a single series and a single regressor (column of the Z matrix).

class KalmanWalker(dist.Distribution):
    """Univariate case: to be mapped over n_series and n_regressors"""

    def __init__(self, Z, scale_noise=1.0, scale_disturbance=1.0, a0=0., P0=1., num_steps=1, *, validate_args=None):
        assert (
            isinstance(num_steps, int) and num_steps > 0
        ), "`num_steps` argument should be a positive integer."
        assert Z.size == num_steps, "Z is a vector of length num_steps."
        self.scale_noise = scale_noise
        self.scale_disturbance = scale_disturbance
        self.Z = Z
        self.a0 = a0
        self.P0 = P0
        self.num_steps = num_steps
        batch_shape, event_shape = jnp.shape(scale_disturbance), (num_steps,)
        super(KalmanWalker, self).__init__(
            batch_shape, event_shape, validate_args=validate_args
        )

    def log_prob(self, value):
        def transition_fn(carry, t):
            at, Pt, log_p = carry
            yt = value[t]
            Zt = self.Z[t]
            vt = yt - Zt * at  # one-step-ahead prediction error
            Ft = Pt * jnp.square(Zt) + jnp.square(self.scale_noise)  # its variance
            Kt = Pt * Zt / Ft  # Kalman gain
            log_p += -0.5 * jnp.sum(jnp.log(jnp.fabs(Ft)) + jnp.square(vt) / Ft)
            at = at + Kt * vt
            Pt = Pt * (1 - Kt * Zt) + jnp.square(self.scale_disturbance)

            return (at, Pt, log_p), (at, Pt)

        (_, _, log_p), (at, Pt) = scan(
            transition_fn, (self.a0, self.P0, 0.0), jnp.arange(self.num_steps)
        )
        return -0.5 * self.num_steps * jnp.log(2 * jnp.pi) + log_p

    def sample(self, key, sample_shape=()):
        def transition_simulation(carry, t):
            at, Pt = carry

            Zt = self.Z[t]
            vt = Zt * at
            Ft = Pt * jnp.square(Zt) + jnp.square(self.scale_noise)
            Kt = Pt * Zt / Ft
            at = at + Kt * vt
            Pt = Pt * (1 - Kt * Zt) + jnp.square(self.scale_disturbance)

            return (at, Pt), (at, Pt)

        _, (at, Pt) = scan(
            transition_simulation, (self.a0, self.P0), jnp.arange(self.num_steps)
        )

        # at the moment, there is no stochastic process
        return at
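
As a hedged sketch of what a stochastic sample could look like for this univariate model: draw the initial coefficient from N(a0, P0), let it follow a random walk with standard deviation scale_disturbance, and add observation noise with standard deviation scale_noise. simulate_walker is a hypothetical standalone function, not part of NumPyro, and the initial-state convention is an assumption chosen to match a filter that starts its recursion from (a0, P0).

```python
import jax
import jax.numpy as jnp


def simulate_walker(key, Z, scale_noise, scale_disturbance, a0=0.0, P0=1.0):
    """Draw one series y (n_steps,) from the random-walk regression model."""
    n_steps = Z.shape[0]
    key_init, key_eta, key_eps = jax.random.split(key, 3)
    # initial coefficient beta_0 ~ N(a0, P0)
    beta0 = a0 + jnp.sqrt(P0) * jax.random.normal(key_init)
    # random-walk increments for steps 1..n_steps-1
    eta = scale_disturbance * jax.random.normal(key_eta, (n_steps - 1,))
    beta = beta0 + jnp.concatenate([jnp.zeros(1), jnp.cumsum(eta)])
    # observation noise on top of the regression mean Z_t * beta_t
    eps = scale_noise * jax.random.normal(key_eps, (n_steps,))
    return Z * beta + eps


Z = jnp.ones(8)
y_sim = simulate_walker(jax.random.PRNGKey(0), Z, scale_noise=0.1, scale_disturbance=0.3)
# with every source of randomness switched off, the series is just Z * a0
y_flat = simulate_walker(jax.random.PRNGKey(0), Z, 0.0, 0.0, a0=2.0, P0=0.0)
```

Inside a Distribution subclass, the same body could live in sample, with sample_shape handled by drawing extra leading dimensions.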

I am confused whether you think I should be:

  • extending this custom KalmanWalker distribution to an arbitrary batch shape, similar to the GaussianRandomWalk implementation. For scan to work, it would be necessary to keep time in the leftmost axis. However, I believe it will be tricky in this case to perform operations like jnp.sum(at * Zt, -1, keepdims=True) when at could now have an arbitrary shape.
  • keeping the univariate version as in this post but using vmap to operate over the number of series and number of regressors, e.g. rather than the element-wise matrix multiplication at * Zt, we can map over each element like at[1, 2] * Zt[2]. I think this is what you intended, but I’m unsure how this would look in the model.

Also, with the custom distribution, is it possible to write something like numpyro.sample("obs", KalmanWalker(...), obs=y) or a similar numpyro.handlers.condition statement (if that is more appropriate for time series)?

Because batch_shape = scale_disturbance.shape, you should use vmap over scale_disturbance (assuming that the other parameters, like scale_noise, a0, …, do not have a batch shape).

is it possible to write something like numpyro.sample("obs", KalmanWalker(...), obs=y)

That’s numpyro syntax for any distribution. I’m not sure why you need condition here. If you want to condition on the past data and forecast, you will need to do filtering to get a_T, P_T (where T is the number of past time steps) and then sampling.
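
A rough sketch of that filter-then-sample recipe in plain Python for the single-regressor case. filter_then_forecast is a hypothetical helper; it reuses the update equations from this thread to reach (a_T, P_T), then draws the coefficient from N(a_T, P_T) and rolls the random walk forward. Z_future, the regressor values over the forecast horizon, is assumed to be known.

```python
import math
import random


def filter_then_forecast(y, Z, Z_future, scale_noise, scale_disturbance,
                         a0=0.0, P0=1.0, seed=0):
    rng = random.Random(seed)
    a, P = a0, P0
    # filtering pass over the observed data
    for yt, Zt in zip(y, Z):
        v = yt - Zt * a
        F = P * Zt ** 2 + scale_noise ** 2
        K = P * Zt / F
        a = a + K * v
        P = P * (1 - K * Zt) + scale_disturbance ** 2
    # (a, P) now describe the coefficient at the first future step
    beta = rng.gauss(a, math.sqrt(P))
    forecast = []
    for Zt in Z_future:
        forecast.append(Zt * beta + rng.gauss(0.0, scale_noise))
        beta = beta + rng.gauss(0.0, scale_disturbance)
    return a, P, forecast


a_T, P_T, forecast = filter_then_forecast(
    [2.0] * 50, [1.0] * 50, [1.0] * 3, scale_noise=1.0, scale_disturbance=0.0
)
```

With no disturbance and unit noise, the filtered mean and variance reduce to the conjugate normal posterior (here 2 * 50 / 51 and 1 / 51), which is a handy sanity check. Repeating the forecast with different seeds gives draws from the predictive distribution.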