Hi, I’m trying to construct a Bayesian structural time series (BSTS) model with just a local linear trend and a quarterly seasonal component. I’m fairly sure I have the model specification set up correctly, but it gives really bad results when trained with SVI on some simulated data.
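For context, this is roughly how I generate the simulated data (a minimal sketch; the specific parameter values here are placeholders rather than the exact ones I used, but the structure is the same: local linear trend plus a period-4 seasonal pattern):

```python
import torch

torch.manual_seed(0)

# Hypothetical parameter values for illustration.
T = 200
level, slope = 10.0, 0.1
seas = [1.5, -0.5, 2.0, -3.0]  # quarterly effects, constrained to sum to zero

data = torch.empty(T)
for t in range(T):
    level = level + slope + 0.05 * torch.randn(()).item()  # level innovation
    slope = slope + 0.01 * torch.randn(()).item()          # slope innovation
    data[t] = level + seas[t % 4] + 0.5 * torch.randn(()).item()  # obs noise
```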
I’m not sure what could be wrong here. Maybe AutoNormal isn’t a good guide for this type of time series model? Or is variational inference itself a poor fit for this kind of model?
The model specification is below. It might look a little weird, but I’m fairly sure the matrices are correct (e.g., the seasonal transition block is only 3x3 because the seasonal effects are constrained to sum to zero for identifiability, which I verified against Kevin Murphy’s book). Still, something must be wrong somewhere: when I draw guide traces for the latent posteriors on in-sample data, the posterior quantiles around the true parameter values (which I know, since the data is simulated) are way off.
```python
import torch
import pyro
import pyro.distributions as dist


def model(data):
    time_dim = data.shape[0]
    data_init, data_mean, data_std = data[0], data.mean(), data.std()

    # - Priors.
    init_level = pyro.sample('init_level', dist.Normal(data_init, data_std).expand([1]).to_event(1))
    init_slope = pyro.sample('init_slope', dist.Normal(0., data_std).expand([1]).to_event(1))
    init_seas = pyro.sample('init_seas', dist.Normal(data_mean, data_std).expand([3]).to_event(1))
    z = torch.cat([init_level, init_slope, init_seas])  # state: [level, slope, s1, s2, s3]
    level_scale = pyro.sample('level_scale', dist.HalfNormal(3.).expand([1]).to_event(1))
    slope_scale = pyro.sample('slope_scale', dist.HalfNormal(3.).expand([1]).to_event(1))
    seas_scale = pyro.sample('seas_scale', dist.HalfNormal(3.).expand([1]).to_event(1))
    # Only the newest seasonal state gets innovation noise; the two lagged
    # seasonal states are (near-)deterministic copies, hence the tiny scales.
    state_scale = torch.cat([level_scale, slope_scale, seas_scale, torch.ones(2) * 0.001])
    obs_scale = pyro.sample('obs_scale', dist.HalfNormal(3.))

    # - State dynamics.
    trend_state_block = torch.tensor([[1., 1.], [0., 1.]])  # local linear trend
    # Quarterly seasonality with a sum-to-zero constraint: 3 states, not 4.
    seas_state_block = torch.cat([torch.full((1, 3), -1.), torch.eye(3)[:-1, :]])
    state_matrix = torch.block_diag(trend_state_block, seas_state_block)  # 5x5

    # - Observation dynamics: the observation is level + current seasonal effect.
    trend_obs_block = torch.tensor([[1., 0.]])
    seas_obs_block = torch.tensor([[1., 0., 0.]])  # floats, matching trend_obs_block
    obs_matrix = torch.cat([trend_obs_block, seas_obs_block], dim=1)  # 1x5

    # - Likelihood.
    state_list, obs_list = [], []
    for t in range(1, time_dim + 1):
        z = pyro.sample(f'z_{t}', dist.Normal(state_matrix @ z, state_scale).to_event(1))
        x = pyro.sample(f'x_{t}', dist.Normal(obs_matrix @ z, obs_scale).to_event(1), obs=data[t - 1])
        state_list.append(z)
        obs_list.append(x)
    pyro.deterministic('state_list', torch.stack(state_list))
    pyro.deterministic('obs_list', torch.stack(obs_list))
```
Any insight is appreciated. Since I’m fitting through SVI and only pulling posteriors of the latents from in-sample periods, I don’t need to do anything with the Kalman filter, right?