I am working on a model of time series in which each series may have a different length. For concreteness, I am modeling the outputs of different instruments over time, where each instrument is brought online at a different time. A simplified version of the model is
where $\mathbf{z}_i\in\mathbb{R}^T$ is the latent time series for the $i^\text{th}$ instrument, $T$ is the length of the observation period, $\tau_i$ is the scale of the random walk innovations of the $i^\text{th}$ instrument, $y_{it}$ is the observation for the $i^\text{th}$ instrument at time $t$, and $\sigma_i$ is the scale of the observation noise for the $i^\text{th}$ instrument.
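Written out, one model consistent with these definitions could look like the sketch below. Gaussian innovations and Gaussian observation noise are assumptions here, and the priors on the initial state, $\tau_i$, and $\sigma_i$ are left unspecified.

$$
z_{i,t+1} \mid z_{it} \sim \mathcal{N}\left(z_{it}, \tau_i^2\right), \qquad y_{it} \mid z_{it} \sim \mathcal{N}\left(z_{it}, \sigma_i^2\right).
$$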
Say the first instrument is brought online at the beginning of the observation period; then I want to introduce $T$ parameters for $\mathbf{z}_1$. If the second instrument is brought online halfway through the observation period, I want $T/2$ parameters for $\mathbf{z}_2$. Such a jagged structure is of course difficult to express in `numpyro`.
What I’ve considered so far
One approach is to forget about the problem and parameterize $\mathbf{z}$ with a dense parameter matrix. But that means I use about twice the number of parameters I actually need.
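A minimal sketch of the dense approach in plain JAX, assuming each instrument stays online from its start time until the end of the observation period. The names `start`, `online_mask`, and `masked_log_likelihood` are hypothetical, and the Normal likelihood is an assumption. The padded entries of the latent matrix remain parameters, but they are excluded from the likelihood.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm


def online_mask(start, T):
    """Boolean (n, T) array that is True where instrument i is online at time t.

    `start[i]` is the (assumed) first time index at which instrument i reports.
    """
    return jnp.arange(T)[None, :] >= jnp.asarray(start)[:, None]


def masked_log_likelihood(y, z, sigma, mask):
    """Sum Normal log-densities over online cells only.

    y, z: (n, T) dense arrays; sigma: (n,) per-instrument noise scales.
    Padded cells contribute exactly zero to the total.
    """
    logp = norm.logpdf(y, loc=z, scale=sigma[:, None])
    return jnp.sum(jnp.where(mask, logp, 0.0))
```

In `numpyro` the same effect can be had by masking the observation distribution, for example with the `.mask(...)` method on distributions. The padded latent entries are then constrained only by their prior.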
Another approach is to sample the time series for each instrument independently in a `for` loop. But I have about 1,000 instruments, and unrolling the loop in `jax.jit` seems problematic.
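To make the unrolling concern concrete, here is a small sketch that counts the traced operations for a Python loop over ragged series. The function `ragged_loop_sum` and its toy body are hypothetical stand-ins for the per-instrument model code, not part of the actual model.

```python
import jax
import jax.numpy as jnp


def ragged_loop_sum(xis, tau):
    """Toy stand-in for per-instrument work: build each walk, then reduce it."""
    total = 0.0
    for i, xi in enumerate(xis):  # plain Python loop: traced once per instrument
        z = tau[i] * jnp.cumsum(xi)
        total = total + jnp.sum(z ** 2)
    return total


def traced_op_count(lengths):
    """Number of top-level equations in the jaxpr for series of the given lengths."""
    xis = [jnp.ones(length) for length in lengths]
    tau = jnp.ones(len(lengths))
    return len(jax.make_jaxpr(ragged_loop_sum)(xis, tau).jaxpr.eqns)
```

Comparing `traced_op_count([3, 5])` with the count for a longer list of lengths shows the traced program growing with the number of instruments, which suggests that tracing and compilation costs would grow similarly with about 1,000 of them.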
I could sample one big vector $\xi$ of length $m=\sum_{i=1}^n T_i$, where $n$ is the number of instruments and $T_i$ is the number of observations I have for instrument $i$. Then construct indices to scatter $\xi$ into the dense matrix $\mathbf{z}$. That seems fiddly for two reasons: first, constructing the indices; second, placing an appropriate prior on $\xi$ such that $\mathbf{z}$ has the desired distribution.
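Both fiddly parts can be sketched as follows, under some assumptions: each instrument is online from a known start index until the end of the period, $\xi$ holds standard-normal innovations (a non-centered parameterization), and the first online value of each walk is itself one innovation from zero. The names `start`, `ragged_indices`, and `scatter_random_walk` are hypothetical. The indices are built once with NumPy outside any jitted code, and the random walk is obtained by scattering the scaled innovations and taking a cumulative sum along time.

```python
import numpy as np
import jax.numpy as jnp


def ragged_indices(start, T):
    """Row and column indices of the online cells of an (n, T) grid, row-major.

    `start[i]` is the (assumed) first online time index of instrument i,
    so instrument i contributes T_i = T - start[i] entries of xi.
    """
    start = np.asarray(start)
    lengths = T - start
    rows = np.repeat(np.arange(len(start)), lengths)
    offsets = np.cumsum(lengths) - lengths  # position in xi of each row's first entry
    cols = np.arange(lengths.sum()) - offsets[rows] + start[rows]
    return rows, cols


def scatter_random_walk(xi, tau, rows, cols, n, T):
    """Scatter scaled innovations into a dense grid, then cumsum along time.

    If xi is i.i.d. standard normal, each row is a Gaussian random walk with
    scale tau[i] from its start index onward; cells before the start are zero.
    """
    steps = jnp.zeros((n, T)).at[rows, cols].set(tau[rows] * xi)
    return jnp.cumsum(steps, axis=1)
```

With this construction the prior on $\xi$ is simply a standard normal on all $m$ entries, so in `numpyro` it could be a single sample site of shape `(m,)`. Only the cells indexed by `rows` and `cols` should enter the likelihood.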
Ideas on how to best parameterize the model would be great; thanks for your time!