Improving SVI approximation for hierarchical time series model with additive degeneracy

I am modeling matrix-valued time series such that the data is an order-3 tensor. These data arise in modeling the number of cases of different infectious diseases in different locations over time. For example, Project Tycho provides weekly data on 50 notifiable diseases in 50 US states from 1888 to 2013.

I am considering a relatively simple tensor factorization model with a time-series prior along the time axis. The model is

\begin{aligned} y_{ijt}&\sim\mathsf{Normal}\left(\hat y_{ijt}, \sigma_{ij}\right)\\ \hat y_{ijt}&=\mu + z_t + a_i + A_{it} + b_j + B_{jt}, \end{aligned}

where y_{ijt} are standardized log1p-transformed case numbers of disease j in state i at time t, \mu is a global intercept, z_t captures global temporal variation, a_i are state-level intercepts, A_{it} captures temporal variation in each state, b_j are disease-specific intercepts, and B_{jt} captures temporal variation for each disease. The observation variance \sigma^2_{ij} depends on both the state and the disease. I’ll provide the full prior specification and numpyro implementation below.

The crux is an additive degeneracy in the definition of \hat y such that increasing \mu by \delta leaves the likelihood unchanged if we decrease all z_t by the same amount \delta. The same applies to any combination of these terms. Using an AutoDiagonalNormal guide for variational inference gives overly confident estimates of all parameters because it neglects the anti-correlation between different terms and the correlation among elements of the same term (e.g., all z have to decrease together to compensate for an increase in \mu). Using an AutoMultivariateNormal or AutoLowRankMultivariateNormal is difficult, because the number of parameters is large. Even if we consider only ten years of data, we have 55k parameters (details below).
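A quick numerical check of this degeneracy, as a sketch: `y_hat_fn` is a hypothetical helper that repeats the model's broadcasting, and the parameter values are random draws rather than fitted values. Both a constant shift between \mu and z and a time-varying shift c_t between z_t and every A_{it} leave \hat y unchanged up to float32 rounding.

```python
import jax
import jax.numpy as jnp


def y_hat_fn(mu, z, a, A, b, B):
    # Hypothetical helper with the model's broadcasting:
    # result has shape (n_states, n_diseases, n_weeks).
    return mu + z + a[:, None, None] + A[:, None, :] + b[None, :, None] + B[None, :, :]


n_states, n_diseases, n_weeks = 4, 3, 6
keys = jax.random.split(jax.random.PRNGKey(0), 6)
mu = jax.random.normal(keys[0], ())
z = jax.random.normal(keys[1], (n_weeks,))
a = jax.random.normal(keys[2], (n_states,))
A = jax.random.normal(keys[3], (n_states, n_weeks))
b = jax.random.normal(keys[4], (n_diseases,))
B = jax.random.normal(keys[5], (n_diseases, n_weeks))

base = y_hat_fn(mu, z, a, A, b, B)

# Constant shift: raise the global intercept, lower every z_t by the same amount.
delta = 2.5
shifted_1 = y_hat_fn(mu + delta, z - delta, a, A, b, B)

# Time-varying shift: move c_t from every A_{it} into z_t.
c = jnp.linspace(-1.0, 1.0, n_weeks)
shifted_2 = y_hat_fn(mu, z + c, a, A - c, b, B)

# Both maximum absolute differences are at the level of float32 rounding.
print(jnp.max(jnp.abs(base - shifted_1)), jnp.max(jnp.abs(base - shifted_2)))
```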

Is there a good way to capture these correlations? The aim is to make (reasonably) well-calibrated forecasts. Thank you already for your help in "Model of time series with different lengths", @martinjankowiak. Your input would be much appreciated!

Prior Specification

I use a broad prior on the global intercept, shrinkage priors on state and disease intercepts, and random walk priors on temporal effects. All scale parameters \tau and \sigma have half-Cauchy priors with unit scale. More explicitly,

\begin{aligned} \mu&\sim\mathsf{Normal}\left(0,100\right)\\ \left\{\tau^{(z)},\tau^{(a)},\tau^{(A)}_i,\tau^{(b)},\tau^{(B)}_j,\sigma_{ij}\right\}&\sim\mathsf{HalfCauchy}\left(1\right)\\ a_i&\sim\mathsf{Normal}\left(0,\tau^{(a)}\right)\\ b_j&\sim\mathsf{Normal}\left(0,\tau^{(b)}\right)\\ z_\bullet&\sim\mathsf{RandomWalk}\left(0, \tau^{(z)}\right)\\ A_{i\bullet}&\sim\mathsf{RandomWalk}\left(0, \tau^{(A)}_i\right)\\ B_{j\bullet}&\sim\mathsf{RandomWalk}\left(0, \tau^{(B)}_j\right). \end{aligned}

I have omitted covariates to model seasonality for brevity.
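Assuming RandomWalk(0, \tau) here means the same thing as numpyro's `GaussianRandomWalk(scale=tau)` used in the implementation below (a walk whose first value and every increment are Normal(0, \tau)), the prior on z unpacks as follows, with T the number of weeks; A_{i\bullet} and B_{j\bullet} are analogous.

```latex
\begin{aligned}
z_1 &\sim \mathsf{Normal}\left(0, \tau^{(z)}\right)\\
z_t \mid z_{t-1} &\sim \mathsf{Normal}\left(z_{t-1}, \tau^{(z)}\right), \qquad t = 2, \ldots, T.
\end{aligned}
```

Under this prior, shifting every z_t by the same \delta leaves all increments unchanged, so only the term for z_1 changes; a common level shift between \mu and z is therefore only weakly penalized, especially while \tau^{(z)} is itself uncertain.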

numpyro implementation

Here is the numpyro implementation of the above model.

```python
from jax import numpy as jnp
from numpyro import distributions as dists
from numpyro import handlers
from numpyro import deterministic, plate, sample


def model(n_states, n_diseases, n_weeks, y=None):
    # Sample scale parameters.
    half_cauchy = dists.HalfCauchy(1)
    tau_a = sample("tau_a", half_cauchy)
    tau_b = sample("tau_b", half_cauchy)
    sigma = sample("sigma", half_cauchy.expand([n_states, n_diseases]).to_event(2))

    # Sample intercepts and series.
    mu, z = sample_summands(n_weeks, 100, "mu", "z")
    with plate("n_states", n_states):
        a, A = sample_summands(n_weeks, tau_a, "a", "A")
    with plate("n_diseases", n_diseases):
        b, B = sample_summands(n_weeks, tau_b, "b", "B")

    # Observation model (could be plate-d for subsampling).
    # Broadcasts to shape (n_states, n_diseases, n_weeks).
    y_hat = deterministic(
        "y_hat",
        mu + z + a[:, None, None] + A[:, None, :] + b[None, :, None] + B[None, :, :],
    )
    sample("y", dists.Normal(y_hat, sigma[..., None]).to_event(3), obs=y)


def sample_summands(n_weeks, tau_x, name_x, name_X):
    """Sub-model for sampling `(mu, z)`, `(a, A)`, and `(b, B)`."""
    x = sample(name_x, dists.Normal(0, tau_x))
    tau_X = sample(f"tau_{name_X}", dists.HalfCauchy(1))
    X = sample(name_X, dists.GaussianRandomWalk(tau_X, n_weeks))
    return x, X


# Draw a sample from the prior as a sanity check.
n_states = 10
n_diseases = 7
n_weeks = 23
seeded_model = handlers.seed(model, 28)
trace = handlers.trace(seeded_model).get_trace(n_states, n_diseases, n_weeks)
# Named `prior_sample` so that numpyro's `sample` is not shadowed.
prior_sample = {
    key: value["value"] for key, value in trace.items()
    if value["type"] == "sample"
}
shapes = {key: value.shape for key, value in prior_sample.items()}
shapes
```

```
{'tau_a': (),
 'tau_b': (),
 'sigma': (10, 7),
 'mu': (),
 'tau_z': (),
 'z': (23,),
 'a': (10,),
 'tau_A': (10,),
 'A': (10, 23),
 'b': (7,),
 'tau_B': (7,),
 'B': (7, 23),
 'y': (10, 7, 23)}
```

Parameter Shapes for Larger Dataset

This section calculates the number of latent parameters for ten years of data from Project Tycho.

```python
import math

n_weeks = 520
n_states = 50
n_diseases = 50
shapes = {
    "mu": (),
    "tau_z": (),
    "z": (n_weeks,),
    "tau_a": (),
    "a": (n_states,),
    "tau_A": (n_states,),
    "A": (n_states, n_weeks),
    "tau_b": (),
    "b": (n_diseases,),
    "tau_B": (n_diseases,),
    "B": (n_diseases, n_weeks),
    "sigma": (n_states, n_diseases),
}
shapes
```

```
{'mu': (),
 'tau_z': (),
 'z': (520,),
 'tau_a': (),
 'a': (50,),
 'tau_A': (50,),
 'A': (50, 520),
 'tau_b': (),
 'b': (50,),
 'tau_B': (50,),
 'B': (50, 520),
 'sigma': (50, 50)}
```

```python
# Number of scalar latent parameters per site (math.prod(()) == 1 for scalars).
sizes = {key: math.prod(value) for key, value in shapes.items()}
sizes, sum(sizes.values())
```

```
({'mu': 1,
  'tau_z': 1,
  'z': 520,
  'tau_a': 1,
  'a': 50,
  'tau_A': 50,
  'A': 26000,
  'tau_b': 1,
  'b': 50,
  'tau_B': 50,
  'B': 26000,
  'sigma': 2500},
 55224)
```
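To put the guide sizes in perspective, here is a rough count of variational parameters for the three Gaussian guide families on the 55,224-dimensional latent space above. The helper `guide_param_counts` is hypothetical, `rank=20` is an arbitrary illustrative choice, and the counts assume the usual parametrizations (a mean plus per-coordinate scales; a mean plus an n x rank factor and a diagonal; a mean plus a lower-triangular Cholesky factor). How a given library actually stores these may differ.

```python
def guide_param_counts(n_latent, rank):
    # Hypothetical helper: free variational parameters per Gaussian guide family.
    diagonal = 2 * n_latent  # mean + one scale per coordinate
    low_rank = n_latent * (rank + 2)  # mean + (n_latent x rank) factor + diagonal
    full_rank = n_latent + n_latent * (n_latent + 1) // 2  # mean + lower-triangular factor
    return diagonal, low_rank, full_rank


n_latent = 55224  # total from the parameter count above
diag, low, full = guide_param_counts(n_latent, rank=20)
print(diag, low, full)
print(f"full-rank parameters in float32: {full * 4 / 1e9:.1f} GB")
```

The full-rank factor alone has about 1.5 billion free entries (roughly 6 GB in float32, before any optimizer state), whereas the low-rank guide grows only linearly in the number of latents.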

Thank you for taking the time to read this far!

i think it can be difficult to deal with this kind of issue elegantly.

you could impose sum constraints on different variables to exactly remove all the degeneracies. but then any simple parametrization of that space will tend to have terrible geometry.
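For a single vector such as a_i, one way to impose a sum-to-zero constraint is to sample n-1 free coordinates and map them into the zero-sum subspace. A minimal sketch, assuming standard normal free coordinates, contrasting two such maps: the naive one (the last element is minus the sum of the others) and an orthonormal basis of the subspace (`zero_sum_basis` is a hypothetical helper built from a QR decomposition). The naive map gives the last element n-1 times the variance of the others, one concrete instance of poor geometry; the orthonormal map treats all elements symmetrically.

```python
import jax
import jax.numpy as jnp


def zero_sum_basis(n):
    # Hypothetical helper: orthonormal basis (n x (n-1)) of {x : sum(x) = 0}.
    # The first n-1 columns of the centering matrix span that subspace.
    centering = jnp.eye(n) - jnp.ones((n, n)) / n
    q, _ = jnp.linalg.qr(centering[:, : n - 1])  # reduced QR: q has shape (n, n-1)
    return q


n = 8
Q = zero_sum_basis(n)
key_naive, key_ortho = jax.random.split(jax.random.PRNGKey(1))
num_draws = 20000

# Naive map: n-1 free N(0, 1) coordinates, last element is minus their sum.
free = jax.random.normal(key_naive, (num_draws, n - 1))
naive = jnp.concatenate([free, -free.sum(axis=1, keepdims=True)], axis=1)

# Orthonormal map: n-1 free N(0, 1) coordinates sent through the basis Q.
ortho = jax.random.normal(key_ortho, (num_draws, n - 1)) @ Q.T

print(naive.var(axis=0))  # roughly 1 for the first n-1 entries, roughly n-1 for the last
print(ortho.var(axis=0))  # roughly (n-1)/n for every entry
```

Whether this helps in the full model is not established here: the coupled constraints (e.g. A_{it} summing to zero over both states and weeks) need a basis for their intersection, for instance mapping an (n_states - 1) x (n_weeks - 1) array X to Q_s X Q_w^T with one such basis per axis.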

you could also forget about the issue, since sufficiently strong priors will tame the degeneracy to some extent. insofar as you don't care about the precise posterior estimates of some of these nuisance variables (?), the degeneracy may not matter much from the point of view of interpreting the posterior. e.g. maybe what you care about most is the denoised estimates y_hat.
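Since \hat y is invariant under the degeneracy, it can be reported directly. If component-level summaries are still wanted, one option is to map each posterior draw to a canonical representative that satisfies the sum constraints without changing \hat y. The sketch below (`canonicalize` and `y_hat_fn` are hypothetical helpers; the draw is random, not from a fitted guide) moves the per-week means over states and diseases into z, then splits the remainder into intercepts and zero-mean series. This only changes how draws are reported; it does not make an over-confident AutoDiagonalNormal guide better calibrated.

```python
import jax
import jax.numpy as jnp


def y_hat_fn(mu, z, a, A, b, B):
    # Hypothetical helper with the model's broadcasting.
    return mu + z + a[:, None, None] + A[:, None, :] + b[None, :, None] + B[None, :, :]


def canonicalize(mu, z, a, A, b, B):
    """Hypothetical helper: identified representative of one draw with the same y_hat.

    Afterwards z, a and b sum to zero, and A (B) sums to zero over both
    states (diseases) and weeks.
    """
    g = a[:, None] + A  # total state effect, shape (n_states, n_weeks)
    h = b[:, None] + B  # total disease effect, shape (n_diseases, n_weeks)
    g_bar, h_bar = g.mean(axis=0), h.mean(axis=0)  # per-week means
    g_c, h_c = g - g_bar, h - h_bar  # zero mean over states / diseases
    m = mu + z + g_bar + h_bar  # part shared by every (state, disease) pair
    mu_new = m.mean()
    z_new = m - mu_new
    a_new = g_c.mean(axis=1)
    A_new = g_c - a_new[:, None]
    b_new = h_c.mean(axis=1)
    B_new = h_c - b_new[:, None]
    return mu_new, z_new, a_new, A_new, b_new, B_new


keys = jax.random.split(jax.random.PRNGKey(2), 6)
draw = (
    jax.random.normal(keys[0], ()),
    jax.random.normal(keys[1], (6,)),
    jax.random.normal(keys[2], (4,)),
    jax.random.normal(keys[3], (4, 6)),
    jax.random.normal(keys[4], (3,)),
    jax.random.normal(keys[5], (3, 6)),
)
canonical = canonicalize(*draw)
print(jnp.max(jnp.abs(y_hat_fn(*draw) - y_hat_fn(*canonical))))  # float32 rounding only
```

For a batch of draws with a leading sample dimension, `jax.vmap(canonicalize)` applies the same map draw by draw.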

a perhaps wackier idea is to try to use Haar transforms one way or the other, something we've found to be of some value in pyro for time series modeling, since they can improve the posterior geometry. maybe you can define random variables that exactly obey the sum constraints you want in some auxiliary space and then transform back to the natural space using a Haar-transform-like construct, all without messing up the geometry horribly.
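One reading of this idea, as a minimal sketch: use an orthonormal Haar transform of a weekly series, whose first (coarsest) coefficient is the series sum divided by \sqrt{n}. Fixing that coefficient to zero and transforming back gives a series that sums to zero exactly, and because the transform is orthonormal, an isotropic Gaussian on the remaining coefficients stays isotropic within the zero-sum subspace. `haar_matrix` is a hypothetical helper, not taken from Pyro; it assumes a power-of-two length (520 weeks would need padding or a more general construction). It also does not reproduce the random-walk prior; a prior on the coefficients would have to be chosen to approximate it.

```python
import jax
import jax.numpy as jnp


def haar_matrix(n):
    # Hypothetical helper: orthonormal Haar matrix for n a power of two.
    # Rows: first the coarser scales (recursively), then the finest differences.
    if n == 1:
        return jnp.ones((1, 1))
    h = haar_matrix(n // 2)
    top = jnp.kron(h, jnp.array([1.0, 1.0]))  # coarser scales, stretched
    bottom = jnp.kron(jnp.eye(n // 2), jnp.array([1.0, -1.0]))  # finest differences
    return jnp.concatenate([top, bottom], axis=0) / jnp.sqrt(2.0)


n_weeks = 8
H = haar_matrix(n_weeks)  # first row is constant: 1 / sqrt(n_weeks)

coeffs = jax.random.normal(jax.random.PRNGKey(3), (n_weeks,))
coeffs = coeffs.at[0].set(0.0)  # coarsest (level) coefficient fixed to zero
z = H.T @ coeffs  # back to the time domain; H is orthonormal, so H.T inverts it

print(z.sum())  # zero up to float32 rounding
```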