I’m trying to adapt Model2 from the Pyro Forecasting III tutorial to a sports setting. I’ve changed the model into a simple linear regression on one covariate, similar to Model1 through Model3 in Forecasting I. During training, `covariates` has shape [32, 31, 15, 1] and `zero_data` has shape [32, 1, 15, 1], where the dimensions correspond to [number of offense teams, number of defense teams, time, 1]. For training, the following code runs fine:
```python
import pyro
import pyro.distributions as dist
from pyro.contrib.forecast import ForecastingModel

# hyper_param_a and hyper_param_b (prior hyperparameters) are defined elsewhere.


class Model1(ForecastingModel):
    # We then implement the .model() method. Since this is a generative model, it shouldn't
    # look at data; however it is convenient to see the shape of data we're supposed to
    # generate, so this inputs a zeros_like(data) tensor instead of the actual data.
    def model(self, zero_data, covariates):
        no_teams, _, duration, _ = zero_data.size()
        _, no_def, _, _ = covariates.size()
        offense_plate = pyro.plate("offense", no_teams, dim=-4)
        defense_plate = pyro.plate("defense", no_def, dim=-3)

        # The first part of the model is a probabilistic program to create a prediction.
        # We use the zero_data as a template for the shape of the prediction.
        with offense_plate:
            bias = pyro.sample("bias", dist.Normal(hyper_param_a, hyper_param_b))
            with defense_plate:
                weight = pyro.sample("weight", dist.Normal(0, 0.1))
        # Sum the covariate contributions over defense teams (dim -3).
        prediction = bias + (weight * covariates).sum(-3, keepdim=True)
        # The prediction should have the same shape as zero_data,
        # (offense, defense, duration, obs_dim), but may have additional
        # sample dimensions on the left.
        assert prediction.shape[-4:] == zero_data.shape[-4:]

        # The next part of the model creates a likelihood or noise distribution.
        # Again we'll be Bayesian and write this as a probabilistic program with
        # priors over parameters.
        with offense_plate:
            noise_scale = pyro.sample("noise_scale", dist.LogNormal(-5, 5))
        noise_dist = dist.Normal(0, noise_scale)

        # The final step is to call the .predict() method.
        with offense_plate:
            self.predict(noise_dist, prediction)
```
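The tensor shapes in this model can be sanity-checked without Pyro by rebuilding them with NumPy, whose broadcasting rules torch follows. This is a sketch under assumptions: the zero arrays stand in for the samples, their shapes are the ones I expect scalar priors to get under plates at dim=-4 (offense) and dim=-3 (defense), and `num_samples` adds one leading axis, as vectorized forecast samples would. `trace_shapes` is a hypothetical helper, not part of Pyro.

```python
import numpy as np

N_OFF, N_DEF, T = 32, 31, 15  # offense teams, defense teams, time steps


def trace_shapes(num_samples=None):
    """Return (prediction.shape, noise_scale.shape) for the model above.

    ASSUMPTION: sample shapes mimic plates at dim=-4 (offense) and dim=-3
    (defense); num_samples adds one leading axis for parallel samples.
    """
    lead = () if num_samples is None else (num_samples,)
    bias = np.zeros(lead + (N_OFF, 1, 1, 1))          # offense plate, dim -4
    weight = np.zeros(lead + (N_OFF, N_DEF, 1, 1))    # offense -4, defense -3
    noise_scale = np.ones(lead + (N_OFF, 1, 1, 1))    # offense plate, dim -4
    covariates = np.zeros((N_OFF, N_DEF, T, 1))
    # Same arithmetic as the model: sum covariate terms over defense teams.
    prediction = bias + (weight * covariates).sum(-3, keepdims=True)
    return prediction.shape, noise_scale.shape


print(trace_shapes())    # training
print(trace_shapes(20))  # forecasting with num_samples=20
```

In both cases `prediction.shape[-4:]` is (32, 1, 15, 1), so the assertion in the model holds; the extra sample axis only sits directly to the left of the offense axis, which matters later at the `residual` site inside `predict`.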
But when running

```python
samples = forecaster(data[..., T0:T1, :], covariates, num_samples=20)
```

I get the following error:

```
ValueError: Shape mismatch inside plate('offense') at site residual dim -4, 32 vs 20
```
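A minimal pure-Python sketch of how this error could arise at forecast time but not during training. Everything here is hypothetical: `check_plates` is a simplified imitation of the shape check behind Pyro's "Shape mismatch inside plate" message, not Pyro's actual code, and `residual_batch_shape` encodes my assumption that `predict` moves the rightmost (time, obs_dim) axes into the event shape and appends one singleton batch axis for the time plate. Under those assumptions, the 20 forecast samples land in dim -4, the dim claimed by the offense plate.

```python
# Hypothetical sketch; helper names and the residual-shape rule are assumptions.


def residual_batch_shape(noise_scale_shape):
    # ASSUMPTION: predict() treats (time, obs_dim) as event dims and appends
    # a singleton batch dim for the time plate.
    return tuple(noise_scale_shape[:-2]) + (1,)


def check_plates(batch_shape, plates):
    """Simplified imitation of a plate shape check.

    plates: list of (name, size, dim) with negative dims.
    Returns the broadcast batch shape, or raises ValueError on a clash.
    """
    # Size-1 dims can broadcast, so treat them as unclaimed (None).
    target = [None if s == 1 else s for s in batch_shape]
    for name, size, dim in plates:
        # Left-pad with unclaimed dims so that `dim` exists.
        if -dim > len(target):
            target = [None] * (-dim - len(target)) + target
        if target[dim] is not None and target[dim] != size:
            raise ValueError(
                f"Shape mismatch inside plate('{name}') dim {dim}, "
                f"{size} vs {target[dim]}"
            )
        target[dim] = size
    return tuple(1 if s is None else s for s in target)


offense_at_4 = [("offense", 32, -4)]

# Training: noise_scale (32, 1, 1, 1) -> residual batch (32, 1, 1).
# No error, but the offense plate silently adds a new dim of size 32.
print(check_plates(residual_batch_shape((32, 1, 1, 1)), offense_at_4))

# Forecasting: a leading sample axis of size 20 -> residual batch (20, 32, 1, 1),
# so dim -4 now holds the samples and clashes with the offense plate.
try:
    check_plates(residual_batch_shape((20, 32, 1, 1, 1)), offense_at_4)
except ValueError as err:
    print(err)
```

With the offense plate at dim=-3 the same forecast shape passes: `check_plates((20, 32, 1, 1), [("offense", 32, -3)])` returns `(20, 32, 1, 1)`. As far as I can tell, the Forecasting III tutorial gives its two plates dim=-3 and dim=-2, leaving dim=-1 for time; following that layout here would also require `.unsqueeze(-1)` on `bias`, `weight`, and `noise_scale` so they line up with the trailing obs_dim axis.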
Any help would be greatly appreciated!
Pyro version ‘1.3.1’