Independent realizations of the same GaussianHMM model

I wrote down a model for a multivariate Ornstein-Uhlenbeck (OU) process (see below). How can I fit this forecasting model to multiple independent realizations of the same process?
I saw in this tutorial that one can do that using the ‘origin’ plate. However, I am not sure how to adapt their solution to my multivariate case (the forecasting module does some internal reshaping which I cannot follow). Do you have any suggestions? Thanks!

    import torch
    import pyro
    import pyro.distributions as dist
    from pyro.contrib.forecast import ForecastingModel

    class Model(ForecastingModel):
        def model(self, zero_data, covariates):
            duration, dim = zero_data.shape[-2:]

            # initial state distribution
            init_dist = dist.Normal(torch.zeros(dim), 100).to_event(1)

            # transition: each row is a Dirichlet draw, then averaged with the identity
            trans_mat = pyro.sample("transition_matrix", dist.Dirichlet(torch.ones(dim)).expand([dim]).to_event(1))  # .transpose(-2, -1)
            trans_mat = (torch.ones(dim).diag_embed() + trans_mat) / 2
            # diagonal scale_tril, i.e. independent transition noise per dimension
            trans_scale = pyro.sample("trans_scale", dist.LogNormal(torch.zeros(dim), 1).to_event(1)).diag_embed()
            trans_dist = dist.MultivariateNormal(torch.zeros(dim), scale_tril=trans_scale)

            # observation model: the state is observed directly, plus diagonal noise
            obs_mat = torch.eye(dim)
            obs_scale = pyro.sample("obs_scale", dist.LogNormal(torch.zeros(dim), 1).to_event(1)).diag_embed()
            obs_dist = dist.MultivariateNormal(torch.zeros(dim), scale_tril=obs_scale)

            noise_model = dist.GaussianHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist, duration=duration)
            assert noise_model.event_shape == (duration, dim)
            self.predict(noise_model, zero_data)
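
To fit the model on several realizations, the data needs shape (N, duration, dim), one slice per independent trajectory. If synthetic data helps for testing, here is a minimal numpy sketch with the same linear-Gaussian structure as the model above. It assumes the row-vector convention z_t = z_{t-1} @ trans_mat (my reading of the "(old, new)" ordering in the GaussianHMM docs), applies one transition before the first observation, and uses an illustrative init_scale of 1.0 rather than the model's prior scale of 100; the helper name simulate_realizations is hypothetical.

```python
import numpy as np


def simulate_realizations(n, duration, trans_mat, trans_scale, obs_scale, init_scale, rng):
    """Hypothetical helper: draw n independent trajectories sharing one set of parameters.

    Assumes z_t = z_{t-1} @ trans_mat + transition noise (row-vector convention)
    and an identity observation matrix, x_t = z_t + observation noise.
    The timing of the first observation is an assumption of this sketch.
    Returns an array of shape (n, duration, dim).
    """
    dim = trans_mat.shape[0]
    z = rng.normal(0.0, init_scale, size=(n, dim))  # one initial state per trajectory
    xs = np.empty((n, duration, dim))
    for t in range(duration):
        # same parameters for every trajectory, independent noise for each one
        z = z @ trans_mat + trans_scale * rng.normal(size=(n, dim))
        xs[:, t] = z + obs_scale * rng.normal(size=(n, dim))
    return xs


rng = np.random.default_rng(0)
dim = 3
# mimic the model: Dirichlet rows averaged with the identity
trans_mat = (np.eye(dim) + rng.dirichlet(np.ones(dim), size=dim)) / 2
data = simulate_realizations(4, 50, trans_mat, trans_scale=0.1, obs_scale=0.05,
                             init_scale=1.0, rng=rng)
print(data.shape)  # (4, 50, 3)
```

Converted with torch.as_tensor(data, dtype=torch.float), an array like this has the (N, duration, dim) layout that the batched model later in the thread unpacks with N, duration, dim = zero_data.shape.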

The formulation looks correct to me. What error did you get?

It works for a single multivariate trajectory, but I get the following error when I try to fit the model to multiple trajectories:

    ValueError: at site "residual", invalid log_prob shape
    Expected [], actual [n_trajectories, 1]
    Try one of the following fixes:
    - enclose the batched tensor in a with pyro.plate(...): context
    - .to_event(...) the distribution being sampled
    - .permute() data dimensions
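The error comes from Pyro comparing a site's log_prob shape with the plates around it: dims are aligned from the right, each plate's dim must hold exactly that plate's size, and dims left of every plate must have size 1. The pure-Python sketch below is a simplified illustration of that comparison, not Pyro's source. It ignores max_plate_nesting and the broadcasting pyro.plate applies to the distribution itself, so it does not settle every case (for example the dim=-3 used later in the thread). The helpers expected_log_prob_shape and shapes_compatible are hypothetical names. Fed the actual shape from the error message, it fails with no plate, passes with a size-N plate at dim=-2, and fails at dim=-1.

```python
from itertools import zip_longest


def expected_log_prob_shape(plates):
    """Hypothetical helper: shape a site's log_prob should have under `plates`.

    `plates` is a list of (size, dim) pairs with negative dim, e.g. [(N, -2)]
    for pyro.plate("data", N, dim=-2). Positions inside the plate range that
    no plate claims are marked -1, meaning "any size is accepted".
    """
    expected = []
    for size, dim in plates:
        assert dim < 0
        if len(expected) < -dim:
            expected = [None] * (-dim - len(expected)) + expected
        expected[dim] = size
    return [-1 if e is None else e for e in expected]


def shapes_compatible(actual, plates):
    """Hypothetical helper: right-align both shapes; missing dims count as size 1."""
    expected = expected_log_prob_shape(plates)
    for a, e in zip_longest(reversed(list(actual)), reversed(expected), fillvalue=1):
        if e != -1 and e != a:
            return False
    return True


N = 5  # number of trajectories (illustrative value)
actual = (N, 1)  # log_prob shape reported in the error above

print(shapes_compatible(actual, []))         # False: Expected [], actual [N, 1]
print(shapes_compatible(actual, [(N, -2)]))  # True:  plate of size N at dim=-2
print(shapes_compatible(actual, [(N, -1)]))  # False: plate on the wrong dim
```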

Outside the forecasting module, I can solve this simply with a data plate:

    noise_dist = dist.GaussianHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist, duration=duration)
    with pyro.plate("data", N):
        pyro.sample("obs", noise_dist, obs=data)
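Conceptually, the data plate declares the N trajectories conditionally independent given the shared parameters, so the joint log-likelihood is the sum of the per-trajectory log-likelihoods. Each GaussianHMM trajectory is jointly Gaussian once flattened, so this identity can be checked in numpy with generic Gaussian vectors standing in for flattened trajectories. The covariance below is made up, not derived from the HMM, and mvn_logpdf is a hypothetical helper.

```python
import numpy as np


def mvn_logpdf(x, cov):
    """Hypothetical helper: log-density of a zero-mean multivariate normal via Cholesky."""
    L = np.linalg.cholesky(cov)
    y = np.linalg.solve(L, x)  # whitened residual, solves L y = x
    k = x.shape[-1]
    # -0.5 x^T cov^-1 x - 0.5 log|cov| - (k/2) log(2 pi)
    return -0.5 * y @ y - np.log(np.diag(L)).sum() - 0.5 * k * np.log(2 * np.pi)


rng = np.random.default_rng(1)
n, k = 3, 4  # n independent realizations, each a k-dimensional Gaussian vector
A = rng.normal(size=(k, k))
cov = A @ A.T + k * np.eye(k)  # shared positive-definite covariance (illustrative)
xs = rng.multivariate_normal(np.zeros(k), cov, size=n)  # shape (n, k)

# All realizations as one big Gaussian with block-diagonal covariance
joint = mvn_logpdf(xs.reshape(-1), np.kron(np.eye(n), cov))
# What a plate over the realizations adds up: per-realization log-densities
per_item = sum(mvn_logpdf(x, cov) for x in xs)
print(np.isclose(joint, per_item))  # True
```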

How should I adapt the code above to satisfy the requirements of the forecasting module? More precisely, what should I pass to the ‘self.predict()’ function?

I think you can simply put your .predict call under plate statements, as in the later part of the tutorial. What error did you get?

You are right; I also needed to specify the right plate dimension. The following model works:

    class Model(ForecastingModel):
        def model(self, zero_data, covariates):
            N, duration, dim = zero_data.shape

            #init
            init_dist = dist.Normal(torch.zeros(dim), 100).to_event(1)

            # transition
            trans_mat = pyro.sample("transition_matrix", dist.Dirichlet(torch.ones(dim)).expand([dim]).to_event(1))  # .transpose(-2, -1)
            trans_mat = (torch.ones(dim).diag_embed() + trans_mat) / 2
            trans_scale = pyro.sample("trans_scale", dist.LogNormal(torch.zeros(dim), 1).to_event(1)).diag_embed()
            trans_dist = dist.MultivariateNormal(torch.zeros(dim), scale_tril=trans_scale)

            # observation model
            obs_mat = torch.eye(dim)
            obs_scale = pyro.sample("obs_scale", dist.LogNormal(torch.zeros(dim), 1).to_event(1)).diag_embed()
            obs_dist = dist.MultivariateNormal(torch.zeros(dim), scale_tril=obs_scale)
            noise_model = dist.GaussianHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist, duration=duration)
            assert noise_model.event_shape == (duration, dim)
            with pyro.plate("data", N, dim=-3):
                self.predict(noise_model, zero_data)

Thank you very much for the help!

It looks like the time plate is at dim=-1, which would put your data plate at -2. Could you try setting dim=-2 on your data plate?

It does work with dim=-2 too. I am still confused about how these dimensions work. I will dig into the tutorial again in search of enlightenment.

You can also find a relevant explanation in the ForecastingModel docs.
