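I am working on a model of time series where each time series may be of different length. For concreteness, I am modeling the outputs of different instruments over time where each of the instruments is brought online at a different time. A simplified version of the model is

$$
\begin{aligned}
z_{it} &\sim \mathsf{Normal}\left(z_{i,t-1}, \tau_i^2\right),\\
y_{it} &\sim \mathsf{Normal}\left(z_{it}, \sigma_i^2\right),
\end{aligned}
$$
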
where $\mathbf{z}_i\in\mathbb{R}^T$ is the latent time series for the $i^\text{th}$ instrument, $T$ is the length of the observation period, $\tau_i$ is the scale of the random walk innovations of the $i^\text{th}$ instrument, $y_{it}$ is the observation for the $i^\text{th}$ instrument at time $t$, and $\sigma_i$ is the scale of the observation noise for the $i^\text{th}$ instrument.

Say the first instrument is brought online at the beginning of the observation period; then I want to introduce $T$ parameters for $\mathbf{z}_1$. If the second instrument is brought online halfway through the observation period, I want $T/2$ parameters for $\mathbf{z}_2$. Such a jagged structure is of course difficult to express in `numpyro`.

# What I've considered so far

One approach is to forget about the problem and parameterize $\mathbf{z}$ with a dense parameter matrix. But that means I use about twice the number of parameters I actually need.
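To make the overhead of the dense approach concrete, here is a minimal NumPy sketch of the bookkeeping. The values of `T` and `starts` are made up for illustration, and it assumes every instrument stays online until the end of the observation period. The boolean matrix `online` marks which entries of the dense matrix correspond to real observations; the remaining entries would be constrained by the prior only.

```python
import numpy as np

T = 8                          # length of the observation period (illustrative)
starts = np.array([0, 4, 6])   # hypothetical times at which each instrument comes online

# online[i, t] is True when instrument i is online at time t.
online = np.arange(T)[None, :] >= starts[:, None]

n_dense = online.size          # parameters used by the dense matrix
n_needed = int(online.sum())   # parameters that correspond to observations

print(n_dense, n_needed)
```

A mask like `online` could also be used to exclude the unobserved entries from the likelihood, for example through the `mask` method on NumPyro distributions, though that does not reduce the number of latent parameters.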

Another approach is to sample the time series for each instrument independently in a `for` loop. But I have about 1,000 instruments, and unrolling the loop in `jax.jit` seems problematic.

I could sample one big vector $\xi$ of length $m=\sum_{i=1}^n T_i$, where $n$ is the number of instruments and $T_i$ is the number of observations I have for instrument $i$. I would then construct indices to scatter $\xi$ into the dense matrix $\mathbf{z}$. That seems fiddly for two reasons: first, constructing the indices; second, placing an appropriate prior on $\xi$ such that $\mathbf{z}$ has the desired distribution.
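Both fiddly parts of the flat-vector idea can be sketched in a few lines of NumPy. This is only one possible arrangement, under some assumptions that are not part of the model above: each instrument stays online until time $T$, each random walk starts from zero, and $\xi$ holds standard normal innovations (a non-centered parameterization) rather than the latent states themselves. The helper names `ragged_indices` and `random_walks_from_flat` are hypothetical. The index arrays depend only on the start times, so they can be built once outside the jitted model.

```python
import numpy as np

def ragged_indices(starts, T):
    """Static index arrays for instruments that come online at `starts`."""
    starts = np.asarray(starts)
    lengths = T - starts                        # T_i for each instrument
    n = len(starts)
    rows = np.repeat(np.arange(n), lengths)     # instrument index of each entry of xi
    offsets = np.concatenate([[0], np.cumsum(lengths)[:-1]])  # segment starts in xi
    cols = np.arange(lengths.sum()) - offsets[rows] + starts[rows]  # time index
    return rows, cols, offsets

def random_walks_from_flat(xi, tau, rows, offsets):
    """Turn iid innovations xi into one random walk per segment, scaled by tau."""
    c = np.cumsum(xi)
    before = np.concatenate([[0.0], c])[offsets]  # running total before each segment
    return tau[rows] * (c - before[rows])         # cumulative sum restarted per segment

T = 6
starts = [0, 3]                                  # illustrative start times
rows, cols, offsets = ragged_indices(starts, T)

rng = np.random.default_rng(0)
xi = rng.standard_normal(rows.size)              # stands in for the sampled flat vector
tau = np.array([0.5, 2.0])                       # illustrative innovation scales

z_flat = random_walks_from_flat(xi, tau, rows, offsets)

# Scatter into the dense (n, T) layout; entries before the start time stay NaN.
z_dense = np.full((len(starts), T), np.nan)
z_dense[rows, cols] = z_flat
```

With this arrangement the prior on $\xi$ is just an iid standard normal, and the random-walk structure comes from the deterministic transform. Inside a JAX model the same scatter would be written with the indexed-update syntax, e.g. `jnp.zeros((n, T)).at[rows, cols].set(z_flat)`, or skipped entirely by indexing the observations with `rows` and `cols` instead. One caveat: subtracting running totals of a very long cumulative sum can lose some floating-point precision.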

Ideas on how to best parameterize the model would be great; thanks for your time!