I am modeling matrix-valued time series such that the data is an order-3 tensor. These data arise in modeling the number of cases of different infectious diseases in different locations over time. For example, Project Tycho provides weekly data of 50 notifiable diseases in 50 US states from 1888 to 2013.

I am considering a relatively simple tensor factorization model with a time series prior on the time axis. The model is

$$
y_{ijt} \sim \mathrm{Normal}\left(\hat y_{ijt}, \sigma^2_{ij}\right),
\qquad
\hat y_{ijt} = \mu + z_t + a_i + A_{it} + b_j + B_{jt},
$$

where $y_{ijt}$ are standardized `log1p`-transformed case numbers of disease $j$ in state $i$ at time $t$, $\mu$ is a global intercept, $z_t$ captures global temporal variation, $a_i$ are state-level intercepts, $A_{it}$ captures temporal variation in each state, $b_j$ are disease-specific intercepts, and $B_{jt}$ captures temporal variation for each disease. The observation variance $\sigma^2_{ij}$ depends on both the state and the disease. I’ll provide the full prior specification and `numpyro` implementation below.
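To make the indexing of $\hat y_{ijt} = \mu + z_t + a_i + A_{it} + b_j + B_{jt}$ concrete, here is a minimal numpy sketch of how the additive mean maps onto array broadcasting with axes ordered (state, disease, week). The helper `additive_mean` and the random inputs are hypothetical and only for illustration; the model code itself uses the same broadcasting pattern.

```python
import numpy as np


def additive_mean(mu, z, a, A, b, B):
    """Return y_hat[i, j, t] = mu + z[t] + a[i] + A[i, t] + b[j] + B[j, t].

    Shapes: z (T,), a (I,), A (I, T), b (J,), B (J, T) -> result (I, J, T).
    """
    return (
        mu
        + z[None, None, :]   # global series, shared by all states and diseases
        + a[:, None, None]   # state intercepts
        + A[:, None, :]      # state-level series
        + b[None, :, None]   # disease intercepts
        + B[None, :, :]      # disease-level series
    )


rng = np.random.default_rng(0)
n_states, n_diseases, n_weeks = 4, 3, 5
mu = rng.normal()
z = rng.normal(size=n_weeks)
a, A = rng.normal(size=n_states), rng.normal(size=(n_states, n_weeks))
b, B = rng.normal(size=n_diseases), rng.normal(size=(n_diseases, n_weeks))
y_hat = additive_mean(mu, z, a, A, b, B)

# Spot-check one entry against the elementwise formula.
i, j, t = 2, 1, 3
print(y_hat.shape, np.isclose(y_hat[i, j, t], mu + z[t] + a[i] + A[i, t] + b[j] + B[j, t]))
```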

The crux is an additive degeneracy in the definition of $\hat y$: increasing $\mu$ by $\delta$ leaves the likelihood unchanged if we decrease all $z_t$ by the same amount $\delta$. The same holds for other combinations of terms, e.g., shifting all $a_i$ up and all $b_j$ down by a constant, or shifting every $A_{it}$ up by $\delta_t$ while shifting $z_t$ down by $\delta_t$. Only the priors distinguish between these configurations. Using an `AutoDiagonalNormal` guide for variational inference gives overly confident estimates of all parameters because it neglects both the anti-correlation between different terms and the correlation among elements of the same term (e.g., all $z_t$ have to decrease together to compensate for an increase in $\mu$). Using an `AutoMultivariateNormal` or `AutoLowRankMultivariateNormal` guide is difficult because the number of parameters is large: even if we consider only ten years of data, we have 55k parameters (details below), so a full-rank covariance alone has about 1.5 billion free entries.
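The degeneracy and its effect on a diagonal guide can be seen in a two-parameter toy version of the model, $y_k \sim \mathrm{Normal}(\mu + z, \sigma^2)$ with $\mu \sim \mathrm{Normal}(0, 10^2)$ and $z \sim \mathrm{Normal}(0, 1)$. This is a sketch under stated assumptions: the toy model, its prior scales, and the simulated data are hypothetical, and it uses the known result that, for a Gaussian posterior with precision $\Lambda$, the optimal fully factorized approximation under $\mathrm{KL}(q \,\|\, p)$ has variances $1/\Lambda_{kk}$. It assumes the ELBO is optimized exactly, which SVI only approximates.

```python
import numpy as np


def log_likelihood(mu, z, y, sigma):
    """Gaussian log-likelihood of y under mean mu + z (up to an additive constant)."""
    return -0.5 * np.sum((y - (mu + z)) ** 2) / sigma**2


rng = np.random.default_rng(1)
sigma, n_obs = 1.0, 100
y = rng.normal(0.3, sigma, size=n_obs)

# 1) Additive degeneracy: shifting mu up and z down by delta leaves the likelihood unchanged.
delta = 2.5
ll = log_likelihood(0.1, 0.2, y, sigma)
ll_shifted = log_likelihood(0.1 + delta, 0.2 - delta, y, sigma)

# 2) Exact Gaussian posterior over (mu, z) for priors mu ~ N(0, s_mu^2), z ~ N(0, s_z^2).
s_mu, s_z = 10.0, 1.0
design = np.ones((1, 2))  # every observation depends on mu + z only
precision = np.diag([1 / s_mu**2, 1 / s_z**2]) + n_obs / sigma**2 * design.T @ design
posterior_cov = np.linalg.inv(precision)

# 3) Optimal diagonal Gaussian under KL(q || p): variances are 1 / diag(precision).
mean_field_var = 1 / np.diag(precision)

correlation = posterior_cov[0, 1] / np.sqrt(posterior_cov[0, 0] * posterior_cov[1, 1])
print(f"likelihood unchanged by shift: {np.isclose(ll, ll_shifted)}")
print(f"posterior correlation(mu, z): {correlation:.3f}")
print("exact marginal sd:", np.sqrt(np.diag(posterior_cov)))
print("mean-field sd:    ", np.sqrt(mean_field_var))
```

With these settings the posterior correlation of $\mu$ and $z$ is close to $-1$, and the diagonal approximation reports marginal standard deviations roughly an order of magnitude smaller than the exact ones, which is the overconfidence described above.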

Is there a good way to capture these correlations? The aim is to make (reasonably) well-calibrated forecasts. Thanks again for your help in "Model of time series with different lengths", @martinjankowiak. Your input would be much appreciated!

# Prior Specification

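I use a broad prior on the global intercept, shrinkage priors on state and disease intercepts, and random walk priors on temporal effects. All scale parameters $\tau$ and $\sigma$ have half-Cauchy priors with unit scale. More explicitly,

$$
\begin{aligned}
\mu &\sim \mathrm{Normal}\left(0, 100^2\right), \\
a_i &\sim \mathrm{Normal}\left(0, \tau_a^2\right), \\
b_j &\sim \mathrm{Normal}\left(0, \tau_b^2\right), \\
z_t &\sim \mathrm{Normal}\left(z_{t-1}, \tau_z^2\right), \\
A_{it} &\sim \mathrm{Normal}\left(A_{i,t-1}, \tau_{A_i}^2\right), \\
B_{jt} &\sim \mathrm{Normal}\left(B_{j,t-1}, \tau_{B_j}^2\right), \\
\tau_a, \tau_b, \tau_z, \tau_{A_i}, \tau_{B_j}, \sigma_{ij} &\sim \mathrm{HalfCauchy}(1),
\end{aligned}
$$

with $z_0 = A_{i0} = B_{j0} = 0$.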

I have omitted covariates to model seasonality for brevity.
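The random walk prior can be made concrete with a short numpy sketch, assuming the construction above (walk starts at $0$, i.i.d. $\mathrm{Normal}(0, \tau^2)$ steps). Two properties matter for the degeneracy: the prior variance of $z_t$ grows like $t\,\tau^2$, and shifting the entire walk by a constant changes only the density of the first step, because all later increments are unchanged. The helper names below are hypothetical.

```python
import numpy as np


def sample_random_walks(n_walks, n_weeks, tau, seed=0):
    """Draw random walks as cumulative sums of N(0, tau^2) increments (x_0 = 0)."""
    rng = np.random.default_rng(seed)
    increments = rng.normal(0.0, tau, size=(n_walks, n_weeks))
    return np.cumsum(increments, axis=-1)


def random_walk_log_prob(x, tau):
    """Log density of one walk x with x_0 = 0 and N(0, tau^2) steps."""
    steps = np.diff(x, prepend=0.0)  # first step is x[0] - 0
    return np.sum(-0.5 * (steps / tau) ** 2 - np.log(tau) - 0.5 * np.log(2 * np.pi))


tau, n_weeks = 0.5, 52
walks = sample_random_walks(20_000, n_weeks, tau)

# Empirical prior variance per week versus the analytic value t * tau^2.
empirical_var = walks.var(axis=0)
analytic_var = np.arange(1, n_weeks + 1) * tau**2

# A constant shift of the whole walk only changes the first-step term.
z, delta = walks[0], 1.7
log_prob_change = random_walk_log_prob(z + delta, tau) - random_walk_log_prob(z, tau)
first_step_change = -0.5 * ((z[0] + delta) ** 2 - z[0] ** 2) / tau**2
print(empirical_var[[0, -1]], analytic_var[[0, -1]], np.isclose(log_prob_change, first_step_change))
```

So the overall level of $z$ relative to $\mu$ is pinned down only by the first-step prior $z_1 \sim \mathrm{Normal}(0, \tau_z^2)$ and the broad prior on $\mu$.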

# `numpyro` Implementation

Here is the `numpyro` implementation of the above model.

```
from jax import numpy as jnp
from numpyro import distributions as dists
from numpyro import handlers
from numpyro import deterministic, plate, sample


def model(n_states, n_diseases, n_weeks, y=None):
    # Sample scale parameters.
    half_cauchy = dists.HalfCauchy(1)
    tau_a = sample("tau_a", half_cauchy)
    tau_b = sample("tau_b", half_cauchy)
    sigma = sample("sigma", half_cauchy.expand([n_states, n_diseases]).to_event(2))
    # Sample intercepts and series.
    mu, z = sample_summands(n_weeks, 100, "mu", "z")
    with plate("n_states", n_states):
        a, A = sample_summands(n_weeks, tau_a, "a", "A")
    with plate("n_diseases", n_diseases):
        b, B = sample_summands(n_weeks, tau_b, "b", "B")
    # Observation model (could be plate-d for subsampling). Broadcasting yields
    # an array of shape (n_states, n_diseases, n_weeks).
    y_hat = deterministic(
        "y_hat",
        mu + z + a[:, None, None] + A[:, None, :] + b[None, :, None] + B[None, :, :],
    )
    sample("y", dists.Normal(y_hat, sigma[..., None]).to_event(3), obs=y)


def sample_summands(n_weeks, tau_x, name_x, name_X):
    """
    Sub-model for sampling `(mu, z)`, `(a, A)`, and `(b, B)`.
    """
    x = sample(name_x, dists.Normal(0, tau_x))
    tau_X = sample(f"tau_{name_X}", dists.HalfCauchy(1))
    X = sample(name_X, dists.GaussianRandomWalk(tau_X, n_weeks))
    return x, X


# Draw a sample from the prior as a sanity check.
n_states = 10
n_diseases = 7
n_weeks = 23
seeded_model = handlers.seed(model, 28)
trace = handlers.trace(seeded_model).get_trace(n_states, n_diseases, n_weeks)
# Use a name other than `sample` so the imported numpyro primitive is not shadowed.
samples = {
    key: value["value"] for key, value in trace.items()
    if value["type"] == "sample"
}
shapes = {key: value.shape for key, value in samples.items()}
shapes
```

```
{'tau_a': (),
'tau_b': (),
'sigma': (10, 7),
'mu': (),
'tau_z': (),
'z': (23,),
'a': (10,),
'tau_A': (10,),
'A': (10, 23),
'b': (7,),
'tau_B': (7,),
'B': (7, 23),
'y': (10, 7, 23)}
```

# Parameter Shapes for Larger Dataset

This section calculates the number of latent parameters for ten years of data from Project Tycho.

```
import math
n_weeks = 520
n_states = 50
n_diseases = 50
shapes = {
"mu": (),
"tau_z": (),
"z": (n_weeks,),
"tau_a": (),
"a": (n_states,),
"tau_A": (n_states,),
"A": (n_states, n_weeks),
"tau_b": (),
"b": (n_diseases,),
"tau_B": (n_diseases,),
"B": (n_diseases, n_weeks),
"sigma": (n_states, n_diseases),
}
shapes
```

```
{'mu': (),
'tau_z': (),
'z': (520,),
'tau_a': (),
'a': (50,),
'tau_A': (50,),
'A': (50, 520),
'tau_b': (),
'b': (50,),
'tau_B': (50,),
'B': (50, 520),
'sigma': (50, 50)}
```

```
sizes = {key: math.prod(value) for key, value in shapes.items()}
sizes, sum(sizes.values())
```

```
({'mu': 1,
'tau_z': 1,
'z': 520,
'tau_a': 1,
'a': 50,
'tau_A': 50,
'A': 26000,
'tau_b': 1,
'b': 50,
'tau_B': 50,
'B': 26000,
'sigma': 2500},
55224)
```
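To put the 55,224 latent parameters in perspective for the multivariate guides, here is a rough count of variational parameters. The assumptions are mine: a full-rank Gaussian is parameterized by a location vector plus a lower-triangular Cholesky factor, the memory figure assumes that factor is stored as a dense $n \times n$ float32 matrix, and a low-rank guide uses a location, a diagonal scale, and an $n \times r$ factor. The helper `low_rank_size` is hypothetical.

```python
n_params = 55_224  # total latent size from the table above

# Full-rank Gaussian: location (n) plus a lower-triangular Cholesky factor (n(n+1)/2).
full_rank = n_params + n_params * (n_params + 1) // 2
dense_factor_gb = n_params**2 * 4 / 1e9  # dense n x n float32 matrix


def low_rank_size(n, rank):
    """Location (n), diagonal scale (n), and an n x rank covariance factor."""
    return n + n + n * rank


print(f"full rank: {full_rank:,} parameters, dense factor ~{dense_factor_gb:.1f} GB in float32")
for rank in (10, 50):
    print(f"rank {rank}: {low_rank_size(n_params, rank):,} parameters")
```

The low-rank count grows only linearly in the number of latent parameters, so whether it is practical depends less on memory than on whether a modest rank can represent the correlation structure induced by the degeneracy.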

Thank you for taking the time to read this far!