I’m trying to adapt Model2 from the Forecasting III tutorial to a sports setting. I’ve changed the model to a simple linear regression on one covariate, similar to Model1 through Model3 in Forecasting I. For training, covariates has size [32, 31, 15, 1] and zero_data has size [32, 1, 15, 1], corresponding to [number of offense teams, number of defense teams, time, 1]. The following training code runs fine:
```python
import pyro
import pyro.distributions as dist
from pyro.contrib.forecast import ForecastingModel

# hyper_param_a and hyper_param_b are defined elsewhere in my script (not shown).

class Model1(ForecastingModel):
    # We then implement the .model() method. Since this is a generative model, it shouldn't
    # look at data; however it is convenient to see the shape of data we're supposed to
    # generate, so this inputs a zeros_like(data) tensor instead of the actual data.
    def model(self, zero_data, covariates):
        no_teams, _, duration, _ = zero_data.size()
        _, no_def, _, _ = covariates.size()
        offense_plate = pyro.plate("offense", no_teams, dim=-4)
        defense_plate = pyro.plate("defense", no_def, dim=-3)

        # The first part of the model is a probabilistic program to create a prediction.
        # We use the zero_data as a template for the shape of the prediction.
        with offense_plate:
            bias = pyro.sample("bias", dist.Normal(hyper_param_a, hyper_param_b))
        with defense_plate:
            weight = pyro.sample("weight", dist.Normal(0, 0.1))
        prediction = bias + (weight * covariates).sum(-3, keepdim=True)
        # The prediction should have the same shape as zero_data
        # (offense, defense, duration, obs_dim), but may have additional
        # sample dimensions on the left.
        assert prediction.shape[-4:] == zero_data.shape[-4:]

        # The next part of the model creates a likelihood or noise distribution.
        # Again we'll be Bayesian and write this as a probabilistic program with
        # priors over parameters.
        with offense_plate:
            noise_scale = pyro.sample("noise_scale", dist.LogNormal(-5, 5))
        noise_dist = dist.Normal(0, noise_scale)

        # The final step is to call the .predict() method.
        with offense_plate:
            self.predict(noise_dist, prediction)
```
But when I run

```python
samples = forecaster(data[..., T0:T1, :], covariates, num_samples=20)
```
I get the following error:
```
ValueError: Shape mismatch inside plate('offense') at site residual dim -4, 32 vs 20
```
Any help would be greatly appreciated!
Pyro version ‘1.3.1’