Hello! Thank you for maintaining Pyro.
I am working on a hierarchical time-series problem in which lower-level series sum up to higher-level (aggregate) series, as described in this book.
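Thanks to the excellent Pyro tutorials, I have already fit a simple multivariate time-series model, but I do not see how to enforce coherence, i.e. that the forecast of each aggregate series equals the sum of the forecasts of its sub-series. In particular, I do not understand how to incorporate the summing matrix $S$ (in the linked book's notation, $\mathbf{y}_t = S\,\mathbf{b}_t$, where $\mathbf{b}_t$ holds the bottom-level series) into the model, since all the time series, including the summed-up one, are treated identically in this formulation. I would appreciate any advice on the matter!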
Minimal example:
import torch
import pyro
from pyro import distributions as dist
from pyro import poutine
from pyro.contrib.forecast import ForecastingModel, Forecaster, eval_crps
from pyro.infer.reparam import LocScaleReparam, SymmetricStableReparam
from pyro.ops.tensor_utils import periodic_repeat
from pyro.ops.stats import quantile
# Synthetic two-level hierarchy: two bottom-level series (y1, y2), each a
# linear trend plus Gaussian noise plus a weekly (period-7) offset, and the
# top-level series y, which is their exact sum.
torch.random.manual_seed(123)
t = torch.arange(50)
n_obs = t.shape[0]

# Weekly seasonal offsets, indexed by day of week (t mod 7).
w1 = torch.tensor([1, 2, 3, 5, 2, 1, 1])
w2 = torch.tensor([2, 1, 2, 3, 2, 2, 2])
day_of_week = t % 7

# Draw y2's noise before y1's so the RNG stream is consumed in the same order.
noise2 = torch.randn(n_obs)
noise1 = torch.randn(n_obs)
y2 = 25 + 0.3 * t + 0.5 * noise2 + w2[day_of_week]
y1 = 10 + 0.15 * t + 0.5 * noise1 + w1[day_of_week]
y = y1 + y2

# Columns are (aggregate, bottom 1, bottom 2); shape (len(t), 3).
Y = torch.stack([y, y1, y2], dim=1)
class Model(ForecastingModel):
    """Random walk with heavy-tailed drift plus day-of-week seasonality.

    Each latent series is the cumulative sum of ``Stable``-distributed
    increments ("drift") plus a 7-periodic seasonal offset. A single scalar
    Gaussian observation-noise scale ``obs_scale`` is shared by all series.

    Hierarchical coherence: pass ``summing_matrix``, the ``S`` matrix of
    shape ``(num_series, num_bottom)``, to give latent dynamics only to the
    bottom-level series. ``S[i, j]`` is 1 when bottom-level series ``j``
    contributes to series ``i``. The noiseless prediction for every data
    column is then ``S @ bottom``. Each aggregate's latent forecast therefore
    equals the sum of its children's latent forecasts by construction. The
    observation noise below is an elementwise Normal, i.e. drawn
    independently per series, so individual noisy samples are not exactly
    coherent. For data whose columns are ``(y, y1, y2)`` with
    ``y = y1 + y2``, use ``S = [[1, 1], [1, 0], [0, 1]]``.

    With ``summing_matrix=None`` (the default), every data column gets its
    own independent latent dynamics and no coherence is enforced.

    Args:
        summing_matrix: optional 2-D tensor (or nested sequence) of shape
            ``(num_series, num_bottom)`` mapping bottom-level series to all
            series.

    Raises:
        ValueError: if ``summing_matrix`` is not 2-D. Also raised when the
            model runs and the matrix's row count differs from the number of
            data columns.
    """

    def __init__(self, summing_matrix=None):
        super().__init__()
        if summing_matrix is not None:
            summing_matrix = torch.as_tensor(
                summing_matrix, dtype=torch.get_default_dtype()
            )
            if summing_matrix.dim() != 2:
                raise ValueError(
                    "summing_matrix must be 2-D (num_series, num_bottom), "
                    f"got shape {tuple(summing_matrix.shape)}"
                )
        self.summing_matrix = summing_matrix

    def model(self, zero_data, covariates):
        """Sample latent variables and register the forecast via ``self.predict``.

        Args:
            zero_data: zeros shaped like the data, ``(duration, num_series)``;
                only its shape is used.
            covariates: unused by this model.
        """
        duration, data_dim = zero_data.shape
        summing_matrix = self.summing_matrix
        if summing_matrix is None:
            # Original behaviour: one independent latent series per data column.
            num_latent = data_dim
        else:
            if summing_matrix.size(0) != data_dim:
                raise ValueError(
                    f"summing_matrix has {summing_matrix.size(0)} rows but the "
                    f"data has {data_dim} columns"
                )
            # Only the bottom-level series get their own latent dynamics.
            num_latent = summing_matrix.size(1)

        drift_stability = pyro.sample("drift_stability", dist.Uniform(1, 2))
        drift_scale = pyro.sample("drift_scale", dist.HalfNormal(10))
        with pyro.plate("origin", num_latent, dim=-2):
            with self.time_plate:
                # Pyro's Stable has no tractable log_prob, so the site is
                # reparameterized. The two reparameterizers are nested
                # deliberately; see the pyro.infer.reparam docs for how
                # they compose.
                with poutine.reparam(config={"drift": LocScaleReparam()}):
                    with poutine.reparam(config={"drift": SymmetricStableReparam()}):
                        drift = pyro.sample(
                            "drift", dist.Stable(drift_stability, 0, drift_scale)
                        )
            with pyro.plate("day_of_week", 7, dim=-1):
                seasonal = pyro.sample("seasonal", dist.Normal(0, 10))
        seasonal = periodic_repeat(seasonal, duration, dim=-1)

        # Random walk: cumulative sum of the drift increments over time.
        motion = drift.cumsum(dim=-1)
        prediction = motion + seasonal
        assert prediction.shape[-2:] == (num_latent, duration)

        # (num_latent, duration) -> (1, duration, num_latent): put series last,
        # as in the data. The singleton at dim -3 presumably keeps leading
        # sample dims aligned with obs_scale.unsqueeze(-1) below; confirm
        # against the Forecaster batching rules.
        prediction = prediction.unsqueeze(-1).transpose(-1, -3)
        assert prediction.shape[-3:] == (1, duration, num_latent), prediction.shape

        if summing_matrix is not None:
            # Aggregate the bottom-level predictions: (..., num_bottom) @ S^T.
            prediction = prediction @ summing_matrix.to(prediction).T
        assert prediction.shape[-3:] == (1, duration, data_dim), prediction.shape

        obs_scale = pyro.sample("obs_scale", dist.HalfNormal(1))
        noise_dist = dist.Normal(0, obs_scale.unsqueeze(-1))
        self.predict(noise_dist, prediction)
# Fit the forecaster by SVI, then report the posterior median of every
# scalar-valued latent site.
pyro.set_rng_seed(1)
pyro.clear_param_store()

# This model uses no covariates: an empty (num_timesteps, 0) tensor.
covariates = torch.zeros(t.shape[0], 0, dtype=torch.float)
forecaster = Forecaster(
    Model(),
    Y,
    covariates,
    learning_rate=0.1,
    num_steps=1000,
)

medians = forecaster.guide.median()
for name, value in medians.items():
    # Skip vector/matrix sites; only scalars are printed.
    if value.numel() != 1:
        continue
    print(f"{name} = {value.item():0.4g}")