Batch Predictions with Forecaster Object

Hey! I’ve built a Forecaster object based on the hierarchical time series model tutorial, which uses the subsampling method explained there to batch-train the model, and this works well.

When it comes to predictions, however, the data are very large and prediction often results in my kernel dying. To remedy this, I’ve tried to implement a batch prediction method that iterates across one of the hierarchical levels as follows:

import numpy as np
import torch

# preprocessed data shape: torch.Size([49, 18316, 14, 1])
# the model takes two hierarchical variables of size 49 and 18316
complete_data = preprocess(df)

# define a generator that yields batches along the second hierarchical dimension
def batch_gen(dataset, batch_size=1000):
    n_samples = dataset.shape[1]
    indices = np.arange(n_samples)
    for start in range(0, n_samples, batch_size):
        end = min(start + batch_size, n_samples)
        batch_idx = indices[start:end]
        yield dataset[:, batch_idx, :, :]

# batch predict
# quantile and eval_crps as imported in the forecasting tutorials
from pyro.contrib.forecast import eval_crps
from pyro.ops.stats import quantile

# T0, T1, T2 are the train/test split times, as in the tutorial
batch_loader = batch_gen(complete_data)
for batch in batch_loader:
    covariates = torch.zeros(batch.size(-2), 0)  # empty covariates over the full duration
    samples = forecaster(
        batch[..., T0:T1, :],
        covariates[T0:T2],
        num_samples=100,
    )
    samples.clamp_(min=0)
    p5, p50, p95 = quantile(samples[..., 0], (0.05, 0.5, 0.95))
    crps = eval_crps(samples, batch[..., T1:T2, :])
    print(f"CRPS: {crps}")

However, this returns an AssertionError thrown by the line

assert model_value.size(dim) > guide_value.size(dim)

in /anaconda3/lib/python3.7/site-packages/pyro/contrib/forecast/util.py.

From what I gather, I can’t feed data with a different number of hierarchical levels into a model that was trained on a given number of levels. In this case, I’m feeding batches of data with 1000 levels in the second hierarchical variable to a model whose guide was trained on 18316 levels. Is this correct? If you happen to know of any other way of batching predictions, or of making predictions on subsamples of the hierarchical levels that the model was trained on, that would be hugely appreciated! Thanks in advance!

can’t you parallelize across num_samples?

Hey! Thanks for the quick response!

I wrote up a batch predict function which only returns one sample

samples = forecaster(
    data[..., T0:T1, :], 
    covariates[T0:T2], 
    num_samples=1
)

and then wrapped it in a vmap, and this is now really fast and doesn’t cause any crashes. Thanks for that!

I am still wondering, though, whether it is possible to make predictions on subsamples of my hierarchical levels. For example, in the hierarchical time series tutorial, if I pass data for just one pair of stations into the model for prediction, as follows:

forecaster(data[:1, :1, T0:T1, :], covariates[T0:T2], num_samples=100)

(as opposed to data[..., T0:T1, :] as in the tutorial), I get the same AssertionError from

/anaconda3/lib/python3.7/site-packages/pyro/contrib/forecast/util.py in _pyro_post_sample(self, msg)
    145             if model_value.size(dim) != guide_value.size(dim):
    146                 break
--> 147         assert model_value.size(dim) > guide_value.size(dim)
    148         assert model_value.shape[dim + 1 :] == guide_value.shape[dim + 1 :]
    149         split = guide_value.size(dim)

Does this mean that in order to get forecasts for this pair of stations, I have to pass data for all station pairs (i.e. hierarchical levels of the same size as the model was trained on)?

i’m not sure i understand your question. @fritzo probably has a better idea, as he wrote the vast majority of code in Forecaster

Ok, thanks.

Returning to the problem of parallelising over samples, I have a function which calls the forecaster object with the necessary data and covariates and returns one sample, which I then wrap in a vmap (so each predictive sample comes from the same data and covariates). The issue is that when I call the vmap, it returns the same sample num_samples times. I’ve tried setting a different rng_seed within the prediction function, but the results are all still identical. Is there any way around this?

i’ve never used vmap in torch. i’m assuming this is expected behavior. in jax the user explicitly passes around random number seeds when dealing with vmap. why are you using vmap anyway? i thought you were trying to reduce memory requirements. shouldn’t you be using a for loop?

Indeed, sorry I misunderstood what you meant initially. Using a for loop solves the problem, thanks!
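For reference, a minimal sketch of that loop (assuming the forecaster, data, covariates, and T0/T1/T2 split times from earlier in the thread; stacking the single draws with torch.cat is just one way to do it):

# minimal sketch of the for-loop approach, assuming forecaster, data,
# covariates, and the T0/T1/T2 split times defined above
import torch

num_samples = 100
draws = []
for _ in range(num_samples):
    # each call draws a single predictive sample on the full data
    sample = forecaster(data[..., T0:T1, :], covariates[T0:T2], num_samples=1)
    draws.append(sample)

# stack the single draws along the leading sample dimension
samples = torch.cat(draws, dim=0)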
