I wrote down a model for a multivariate OU process (see below). How can I fit this forecasting model to multiple independent realizations of the same process?
I saw in this tutorial that one can do that using the ‘origin’ plate. However, I am not sure how to adapt their solution to my multivariate case (the forecasting module does some internal reshaping which I cannot follow). Do you have any suggestions? Thanks!
```python
import torch
import pyro
import pyro.distributions as dist
from pyro.contrib.forecast import ForecastingModel


class Model(ForecastingModel):
    def model(self, zero_data, covariates):
        duration, dim = zero_data.shape[-2:]
        # init
        init_dist = dist.Normal(torch.zeros(dim), 100).to_event(1)
        # transition
        trans_mat = pyro.sample("transition_matrix", dist.Dirichlet(torch.ones(dim)).expand([dim]).to_event(1))  # .transpose(-2, -1)
        # (I + D) / 2: each row of D sums to 1, so each row of trans_mat does too
        trans_mat = (torch.ones(dim).diag_embed() + trans_mat) / 2
        trans_scale = pyro.sample("trans_scale", dist.LogNormal(torch.zeros(dim), 1).to_event(1)).diag_embed()
        trans_dist = dist.MultivariateNormal(torch.zeros(dim), scale_tril=trans_scale)
        # observation model
        obs_mat = torch.eye(dim)
        obs_scale = pyro.sample("obs_scale", dist.LogNormal(torch.zeros(dim), 1).to_event(1)).diag_embed()
        obs_dist = dist.MultivariateNormal(torch.zeros(dim), scale_tril=obs_scale)
        noise_model = dist.GaussianHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist, duration=duration)
        assert noise_model.event_shape == (duration, dim)
        self.predict(noise_model, zero_data)
```
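For testing a batched fit it helps to have synthetic data in the layout the forecasting module expects, batch_shape + (duration, dim), here (N, duration, dim). The numpy sketch below draws N independent trajectories following GaussianHMM's generative process as I understand it (z = z @ transition_matrix + noise, then x = z @ observation_matrix + noise), with parameters drawn the way the priors in the model above suggest. The function name `simulate_trajectories` is hypothetical, and this is plain numpy, not Pyro code.

```python
import numpy as np


def simulate_trajectories(n, duration, dim, seed=0):
    """Hypothetical helper: draw n independent trajectories, each of shape
    (duration, dim), from a linear-Gaussian state-space model that mirrors
    the structure of the model above. Returns (data, trans_mat)."""
    rng = np.random.default_rng(seed)
    # Same construction as the model: rows of D ~ Dirichlet(1), then (I + D) / 2.
    d = rng.dirichlet(np.ones(dim), size=dim)
    trans_mat = (np.eye(dim) + d) / 2
    # Per-dimension noise scales, drawn like the LogNormal(0, 1) priors.
    trans_scale = rng.lognormal(0.0, 1.0, size=dim)
    obs_scale = rng.lognormal(0.0, 1.0, size=dim)

    data = np.empty((n, duration, dim))
    # Initial hidden state ~ Normal(0, 100), one per trajectory.
    z = rng.normal(0.0, 100.0, size=(n, dim))
    for t in range(duration):
        # Row-vector convention: z_t = z_{t-1} @ trans_mat + transition noise.
        z = z @ trans_mat + rng.normal(size=(n, dim)) * trans_scale
        # obs_mat is the identity, so x_t = z_t + observation noise.
        data[:, t, :] = z + rng.normal(size=(n, dim)) * obs_scale
    return data, trans_mat
```

The returned `data` has shape (N, duration, dim); `torch.as_tensor(data, dtype=torch.float32)` should then give a tensor in the layout discussed in this thread.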
The formulation looks correct to me. What error do you get?
It works for a single multivariate trajectory, but I get the following error when I try to fit the model to multiple trajectories:
```
ValueError: at site "residual", invalid log_prob shape
  Expected [], actual [n_trajectories, 1]
  Try one of the following fixes:
  - enclose the batched tensor in a with pyro.plate(...): context
  - .to_event(...) the distribution being sampled
  - .permute() data dimensions
```
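The message comes from Pyro's check that a site's log_prob shape agrees with the plates around it. As I understand it, the rule is: build an expected shape from the enclosing plates (each plate's size at its dim, every other position unconstrained), then compare it right-aligned with the actual shape. The pure-Python sketch below re-implements that rule in simplified form; the helper names are hypothetical and details such as truncation to `max_plate_nesting` are left out. In the failing run, the "residual" site has no plate around it, so no batch axis is allowed.

```python
from itertools import zip_longest


def expected_log_prob_shape(plates):
    """Build the expected shape from (size, dim) pairs, one per enclosing
    plate. Positions no plate claims are -1, meaning "any size"."""
    shape = []
    for size, dim in plates:
        if len(shape) < -dim:
            # Left-pad with wildcards so that index `dim` exists.
            shape = [-1] * (-dim - len(shape)) + shape
        shape[dim] = size
    return shape


def log_prob_shape_ok(actual_shape, plates):
    """Compare right-aligned; a missing dimension on either side counts as 1."""
    expected = expected_log_prob_shape(plates)
    pairs = zip_longest(reversed(list(actual_shape)), reversed(expected), fillvalue=1)
    return all(exp == -1 or exp == act for act, exp in pairs)
```

With no plate, the expected shape is `[]` and the leading `n_trajectories` axis of `[n_trajectories, 1]` is rejected; a plate of size `n_trajectories` at `dim=-2` gives the expected shape `[n_trajectories, -1]`, which accepts it.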
Outside the forecasting module, I can solve this simply with a data plate:
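```python
noise_dist = dist.GaussianHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist, duration=duration)
# data has shape (N, duration, dim); the plate marks the N trajectories as independent
with pyro.plate("data", N):
    pyro.sample("obs", noise_dist, obs=data)
```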
How should I adapt the code above to satisfy the requirements of the forecasting module? More precisely, what should I pass to `self.predict()`?
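Why the data plate is the right tool: `GaussianHMM.log_prob` returns the marginal likelihood of a whole trajectory (hidden states integrated out), and the plate declares the N trajectories conditionally independent given the shared parameters, so their log-likelihoods simply add. The numpy sketch below makes that concrete with a textbook Kalman filter rather than Pyro's internal Gaussian message passing; for a linear-Gaussian model both should give the same marginal likelihood. The function names are hypothetical, and the conventions (row vectors, initial state drawn before the first transition) are my reading of GaussianHMM.

```python
import numpy as np


def kalman_log_likelihood(x, init_cov, trans_mat, trans_cov, obs_mat, obs_cov):
    """Marginal log-likelihood of one trajectory x, shape (duration, obs_dim), under
    z_t = z_{t-1} @ trans_mat + w_t,  x_t = z_t @ obs_mat + v_t,
    with the initial state z ~ N(0, init_cov) drawn before the first transition."""
    mean = np.zeros(trans_mat.shape[0])
    cov = init_cov
    total = 0.0
    for x_t in x:
        # Predict: push the state through one transition.
        mean = mean @ trans_mat
        cov = trans_mat.T @ cov @ trans_mat + trans_cov
        # Predictive distribution of the observation x_t.
        resid = x_t - mean @ obs_mat
        s = obs_mat.T @ cov @ obs_mat + obs_cov
        _, logdet = np.linalg.slogdet(s)
        total += -0.5 * (resid @ np.linalg.solve(s, resid) + logdet + len(x_t) * np.log(2 * np.pi))
        # Update: condition the hidden state on x_t.
        gain = cov @ obs_mat @ np.linalg.inv(s)  # shape (hidden_dim, obs_dim)
        mean = mean + resid @ gain.T
        cov = cov - gain @ obs_mat.T @ cov
    return total


def batch_log_likelihood(data, init_cov, trans_mat, trans_cov, obs_mat, obs_cov):
    """Independent trajectories along axis 0 of data, shape (N, duration, obs_dim):
    the joint log-likelihood is the sum of the per-trajectory terms."""
    return sum(
        kalman_log_likelihood(x, init_cov, trans_mat, trans_cov, obs_mat, obs_cov)
        for x in data
    )
```

Because the parameters are sampled outside the plate and shared by all trajectories, adding a data plate changes only how many likelihood terms are summed, not the model itself. For the model in this thread, `init_cov` would be `100**2 * np.eye(dim)`.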
I think you can simply put your `.predict` call under plate statements, as in the later part of the tutorial.
You are right; I just needed to specify the correct plate dimension. The following model works:
```python
import torch
import pyro
import pyro.distributions as dist
from pyro.contrib.forecast import ForecastingModel


class Model(ForecastingModel):
    def model(self, zero_data, covariates):
        # zero_data has shape (N, duration, dim): N independent trajectories
        N, duration, dim = zero_data.shape
        # init
        init_dist = dist.Normal(torch.zeros(dim), 100).to_event(1)
        # transition
        trans_mat = pyro.sample("transition_matrix", dist.Dirichlet(torch.ones(dim)).expand([dim]).to_event(1))  # .transpose(-2, -1)
        trans_mat = (torch.ones(dim).diag_embed() + trans_mat) / 2
        trans_scale = pyro.sample("trans_scale", dist.LogNormal(torch.zeros(dim), 1).to_event(1)).diag_embed()
        trans_dist = dist.MultivariateNormal(torch.zeros(dim), scale_tril=trans_scale)
        # observation model
        obs_mat = torch.eye(dim)
        obs_scale = pyro.sample("obs_scale", dist.LogNormal(torch.zeros(dim), 1).to_event(1)).diag_embed()
        obs_dist = dist.MultivariateNormal(torch.zeros(dim), scale_tril=obs_scale)
        noise_model = dist.GaussianHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist, duration=duration)
        assert noise_model.event_shape == (duration, dim)
        # Trajectories are independent given the parameters sampled above.
        with pyro.plate("data", N, dim=-3):
            self.predict(noise_model, zero_data)
```
Thank you very much for the help!
It looks like the time plate is at dim=-1, so your data plate should go at dim=-2. Could you try setting dim=-2 for your data plate?
It does work with dim=-2 too. I am still confused about how these dimensions work. I will dig into the tutorial again in search of enlightenment.
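The rule of thumb from the reply suggesting dim=-2 can be written down for data shaped batch_shape + (duration, obs_dim): obs_dim is an event dimension and gets no plate, time takes plate dim -1, and batch axes are numbered leftwards from -2. The helper below is hypothetical and encodes only that convention; it does not model the forecasting module's internal reshaping, so it does not explain why dim=-3 also ran in this thread.

```python
def batch_plate_dims(data_shape):
    """Hypothetical helper: suggest a plate dim for each batch axis of data
    shaped batch_shape + (duration, obs_dim), assuming time uses plate dim -1."""
    batch_rank = len(data_shape) - 2
    # The rightmost batch axis sits next to time and gets -2, the next one
    # to its left gets -3, and so on.
    return [-(batch_rank + 1 - i) for i in range(batch_rank)]
```

For (N, duration, dim) this gives `[-2]`; for a (num_origins, num_destinations, duration, 1) array it gives `[-3, -2]`, which as far as I can tell matches the origin/destination plates in Pyro's hierarchical forecasting tutorial.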
You can also find a relevant explanation in the `ForecastingModel` docs.