I’m getting some really weird behavior with the torch.Tensor.cumsum(dim) method while playing around with the DLM tutorial.
Basically, I’m pretty sure the pyro.sample("drift", ...) statement inside a pyro.plate (technically, self.time_plate from the pyro.contrib.forecast module) creates a latent variable of shape [1095, 6] (1095 observations across 6 predictor time series). Since the sample is inside a plate, the 1095 rows are treated as i.i.d. local latent variables (in the tutorial, these are the latent drift variables).
Then outside the plate, we take a simple cumulative sum down the rows to mimic a trend component for each of the 6 predictor time series. Since the tensor is only 2-dimensional, tensor.cumsum(0) should give exactly the same result as tensor.cumsum(-2), right? And yet, when I run and fit the model, the fit is far better with tensor.cumsum(-2) than with tensor.cumsum(0).
I tried both arguments (dim=0 and dim=-2) in a simpler model and they gave near-identical fits, so I’m not sure what I’m missing in this model. I also tried dim=0, dim=1, dim=2, dim=-2, and dim=-1, and I can’t reproduce the fit from dim=-2 with any of the others, even though for a tensor of fixed rank every negative dim (counted from the right) should have an equivalent positive dim (counted from the left), correct?
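Here’s a standalone version of that sanity check on a plain 2-D tensor (the shape is made up, just standing in for the [1095, 6] drift tensor):

```python
import torch

# Deterministic 2-D stand-in for the [1095, 6] drift tensor.
drift = torch.arange(12.0).reshape(4, 3)

# For a rank-2 tensor, dim=0 and dim=-2 name the same axis,
# so the two cumulative sums are identical.
assert torch.equal(drift.cumsum(dim=0), drift.cumsum(dim=-2))
print("identical for 2-D input")
```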
The code for the model is below. zero_data has shape [1095, 1] (or slightly fewer rows, depending on the train-set boundaries) and covariates has shape [1095, 6], for 5 predictors plus 1 intercept.
class DLM(ForecastingModel):
    def model(self, zero_data, covariates):
        feature_dim = covariates.size(-1)

        # Global scale.
        drift_scale = pyro.sample("drift_scale",
                                  dist.LogNormal(-10, 10).expand([feature_dim]).to_event(1))

        # Local latent drift variables, sampled i.i.d. across time.
        with self.time_plate:
            with poutine.reparam(config={"drift": LocScaleReparam()}):
                drift = pyro.sample("drift",
                                    dist.Normal(torch.zeros(covariates.size()), drift_scale).to_event(1))

        # Outside the plate.
        weights = drift.cumsum(dim=-2)  # <<ISSUE: dim=0 gives worse fit???>>
        pyro.deterministic("weights", weights)

        prediction = (weights * covariates).sum(-1, keepdim=True)
        pyro.deterministic("prediction", prediction)

        # Noise distribution.
        scale = pyro.sample("noise_scale", dist.LogNormal(-5, 10).expand([1]).to_event(1))
        noise_dist = dist.Normal(0, scale)
        self.predict(noise_dist, prediction)
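For what it’s worth, the only way I’ve managed to make the two calls disagree in isolation is by giving the tensor an extra leading batch dimension. I don’t know whether Pyro is doing something like that to drift internally during inference, so this is just a guess on my part:

```python
import torch

# Hypothetical: suppose drift picked up a leading batch dim,
# e.g. [batch, 1095, 6] instead of [1095, 6].
drift = torch.arange(24.0).reshape(2, 4, 3)

# dim=0 now accumulates over the batch axis,
# while dim=-2 still accumulates down the time axis.
print(torch.equal(drift.cumsum(dim=0), drift.cumsum(dim=-2)))  # False
```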