HMM with both continuous and discrete observations

Hi,

I’m a bit new to Pyro and have been working on an HMM for protein structure prediction, which will have four observations per (discrete) timestep when done:

The observations are as follows:

One continuous observation containing two values. It is sampled from a custom distribution I have implemented (a bivariate von Mises) and has the form tensor([[phi, psi]]), i.e. shape (1, 2). The values are in radians.

and

Three discrete observations, each sampled from its own categorical distribution.
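For reference, a common parameterization of the bivariate von Mises is the sine model, with density proportional to exp(k1·cos(phi - mu) + k2·cos(psi - nu) + lam·sin(phi - mu)·sin(psi - nu)). Its parameters match the mu, nu, k1, k2, lam arguments passed to BVM.BivariateVonMises in the model below, but whether the custom implementation uses exactly this form is an assumption. The sketch below (plain PyTorch, hypothetical function name `bvm_sine_log_prob`) evaluates the log-density for scalar parameters and approximates the normalizing constant with a periodic grid sum instead of the usual Bessel-function series.

```python
import math
import torch

def bvm_sine_log_prob(phi, psi, mu, nu, k1, k2, lam, grid_size=256):
    """Hypothetical helper (not the poster's BVM code): log-density of the
    sine-model bivariate von Mises for scalar parameters mu, nu, k1, k2, lam.
    phi and psi may be tensors of any (matching or broadcastable) shape."""
    def unnorm(a, b):
        # Unnormalized log-density of the sine model
        return (k1 * torch.cos(a - mu) + k2 * torch.cos(b - nu)
                + lam * torch.sin(a - mu) * torch.sin(b - nu))

    # Periodic trapezoid rule over [-pi, pi)^2; for smooth periodic
    # integrands like this one it converges very quickly in grid_size.
    step = 2 * math.pi / grid_size
    grid = -math.pi + step * torch.arange(grid_size, dtype=torch.float64)
    a, b = torch.meshgrid(grid, grid, indexing="ij")
    log_z = torch.logsumexp(unnorm(a, b).reshape(-1), dim=0) + 2 * math.log(step)
    return unnorm(phi, psi) - log_z
```

With lam = 0 the two angles are independent and the normalizer reduces to 4π²·I0(k1)·I0(k2), which is a handy sanity check for any implementation.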

It’s going alright when I only have categorical observations, but I run into massive problems when I try to do the same model but with the continuous observations.

What am I doing wrong?

```python
import torch
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.ops.indexing import Vindex

import BVM  # my custom bivariate von Mises distribution (not part of Pyro)


def model_0(seq_AA, seq_DSSP, seq_dihedral, lengths, args, batch_size=None, include_prior=True):
    assert not torch._C._get_tracing_state()
    num_sequences, max_length, data_dim = seq_AA.shape

    data_dim_AA = 20  # number of amino acid types
    data_dim_dihedral = 2  # (phi, psi)

    # Global parameters; their prior terms are masked out when include_prior=False
    with poutine.mask(mask=include_prior):
        # Transition matrix: row i is p(x_t | x_{t-1} = i)
        probs_x = pyro.sample("probs_x",
                              dist.Dirichlet(0.9 * torch.eye(args.hidden_dim) + 0.1)
                                  .to_event(1))

        # Amino acid emission weights (normalized by Categorical), one row per hidden state
        probs_y_AA = pyro.sample("probs_y_AA",
                                 dist.Beta(0.1, 0.9)
                                     .expand([args.hidden_dim, data_dim_AA])
                                     .to_event(2))

        # Dihedral (bivariate von Mises) emission parameters, one set per hidden state
        probs_y_means = pyro.sample("probs_y_means",
                                    dist.VonMises(torch.tensor([0.]), torch.tensor(90.))
                                        .expand([args.hidden_dim, data_dim_dihedral])
                                        .to_event(2))

        probs_y_lam = pyro.sample("probs_y_lam",
                                  dist.Normal(torch.tensor([0.]), torch.tensor(1.))
                                      .expand([args.hidden_dim, 1])
                                      .to_event(2))

        probs_y_k1k2 = pyro.sample("probs_y_k1k2",
                                   dist.Gamma(torch.tensor([70.]), torch.tensor(1.))
                                       .expand([args.hidden_dim, 2])
                                       .to_event(2))

    for i in pyro.plate("sequences", len(lengths), batch_size):
        length = int(lengths[i].item())
        sequence_AA = seq_AA[i, :length]
        sequence_dihedral = seq_dihedral[i, :length]
        state_x = 0
        for t in pyro.markov(range(length)):
            state_x = pyro.sample("state_x_{}_{}".format(i, t),
                                  dist.Categorical(Vindex(probs_x)[..., state_x, :]),
                                  infer={"enumerate": "sequential"})

            pyro.sample("y_AA_{}_{}".format(i, t),
                        dist.Categorical(Vindex(probs_y_AA)[..., state_x.squeeze(-1), :]),
                        obs=sequence_AA[t])

            # Note: lam keeps a trailing dimension of size 1 here, unlike mu, nu, k1, k2
            pyro.sample("y_dihedral_{}_{}".format(i, t), BVM.BivariateVonMises(
                mu=Vindex(probs_y_means)[..., state_x.squeeze(-1), 0],
                nu=Vindex(probs_y_means)[..., state_x.squeeze(-1), 1],
                k1=Vindex(probs_y_k1k2)[..., state_x.squeeze(-1), 0],
                k2=Vindex(probs_y_k1k2)[..., state_x.squeeze(-1), 1],
                lam=Vindex(probs_y_lam)[..., state_x.squeeze(-1), :]),
                        obs=sequence_dihedral[t].reshape([1, 2]))
```
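In the model above, every observation at timestep t is conditioned on the same state_x, so, treating the observations as conditionally independent given the state, their log-probabilities simply add for each hidden state. The plain-PyTorch sketch below (hypothetical names, not a Pyro API) builds that [T, S] table for any number of categorical observation types; a continuous term, such as a per-state bivariate von Mises log-density, would be added to the same table.

```python
import torch

def emission_log_likelihood(obs, emission_probs):
    """Hypothetical helper: per-timestep, per-state log-likelihood of several
    independent categorical observations.

    obs:            LongTensor [T, K], one column per observation type
    emission_probs: list of K tensors, the k-th of shape [S, C_k]
                    (row s = categorical probabilities for hidden state s)
    returns:        tensor [T, S] with sum_k log p_k(obs[t, k] | state s)
    """
    T = obs.shape[0]
    S = emission_probs[0].shape[0]
    total = torch.zeros(T, S, dtype=emission_probs[0].dtype)
    for k, probs in enumerate(emission_probs):
        log_probs = probs.log()  # [S, C_k]
        # Select the observed category at every timestep: [S, T] -> [T, S]
        total = total + log_probs[:, obs[:, k]].T
    return total
```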

Edit 1:
I am also getting the error “found vars in model but not in guide” even though I’m using enumeration, which confuses me a bit. The error refers to the samples from probs_x.

My code is taken from the hmm example.

Edit 2:
Sorry, I forgot to mention that the final model will have 1 continuous observation and 3 discrete observations, but the current model only has 1 continuous and 1 discrete observation.

Hi @Christian,

First, I realize this doesn’t answer your question :smile:, but I’d recommend replacing your sequential pyro.markov loop with a DiscreteHMM distribution, as in model_7() of the hmm tutorial. This seems challenging because you have multiple observation distributions, but I’d be happy to help see if we can get something working. The advantage is that DiscreteHMM is 1-2 orders of magnitude faster. I guess I should ask: how long are your sequences?
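Much of that speedup likely comes from vectorizing over time: the forward recursion of an HMM is a chain of log-space matrix products, and because that product is associative the matrices can be combined pairwise, in O(log T) sequential rounds rather than T steps. The sketch below illustrates the idea in plain PyTorch; it is not Pyro's DiscreteHMM code, and the function names are hypothetical.

```python
import torch

def logmatmulexp(a, b):
    """log(exp(a) @ exp(b)) computed stably; a: [..., I, J], b: [..., J, K]."""
    return torch.logsumexp(a.unsqueeze(-1) + b.unsqueeze(-3), dim=-2)

def tree_logmatmulexp(mats):
    """Hypothetical helper: ordered product of log-space matrices mats [T, S, S],
    combined pairwise so each round halves the number of matrices."""
    while mats.shape[0] > 1:
        carry = None
        if mats.shape[0] % 2 == 1:
            # Odd count: set the last matrix aside for the next round
            carry = mats[-1:]
            mats = mats[:-1]
        # Multiply neighbours (0,1), (2,3), ... in one batched call
        mats = logmatmulexp(mats[0::2], mats[1::2])
        if carry is not None:
            mats = torch.cat([mats, carry], dim=0)
    return mats[0]
```

For an HMM, the t-th matrix would hold log p(x_t = j | x_{t-1} = i) + log p(y_t | x_t = j), so every pair of timesteps in a round can be processed at once.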

To address your question, what is your guide? Does your guide contain a sample statement for “probs_x”? Of course, Dirichlet distributions are not discrete, so they will not be enumerated.


Hi Fritzo, thank you for the swift reply :smile:

I am using the AutoDelta guide as in the hmm example. The only thing I’ve changed is the model and the data.

The sequences range in length from ~100 to ~850. I have padded the data so that all sequences have the same length, so I will need to use masking in the model. The data set has fewer than 500 sequences, but I will later use a much larger one.
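For padded data, the usual approach is a boolean mask built from the sequence lengths, so that padded timesteps contribute nothing to the log-likelihood; in a Pyro model that mask is what gets passed to poutine.mask (the hmm example masks each timestep with a condition of the form t < lengths). A minimal plain-PyTorch sketch with hypothetical helper names:

```python
import torch

def length_mask(lengths, max_length):
    """Hypothetical helper: boolean mask [N, T], True where timestep t
    lies inside sequence n (i.e. t < lengths[n])."""
    return torch.arange(max_length) < lengths.unsqueeze(-1)

def masked_sequence_log_prob(step_log_probs, lengths):
    """Hypothetical helper: sum per-timestep log-probs [N, T] per sequence,
    with padded positions set to zero so they do not contribute."""
    mask = length_mask(lengths, step_log_probs.shape[-1])
    return step_log_probs.masked_fill(~mask, 0.0).sum(-1)
```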

Thank you for the advice. I will take another look at DiscreteHMM and model_7; I must have misunderstood how to use it the first time around.

Is it OK to have multiple observed distributions at once (four observations per timestep), like what I am trying to achieve? My intuition is that it’s not an issue, but I have not found any Pyro examples showing this structure; they all have only one observed output.

> Is it ok to have multiple observed distributions at once like what I am trying to achieve?

It is OK in general Pyro HMM models, but it is not currently supported by our faster DiscreteHMM class. However, I was recently made aware of @srush’s different implementation, torch_struct.HMM, which I believe could be made to work with multiple observation functions. You’d need to manually call .log_prob() or poutine.block(poutine.trace(...)) or something similarly complex, so I’d recommend starting with a sequential HMM using Pyro enumeration.
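The manual route amounts to computing the per-state emission log-likelihoods yourself, summed over all observation types, and then marginalizing the hidden states with the forward algorithm. A minimal sketch in plain PyTorch, with hypothetical names (this is neither a Pyro nor a torch_struct API); in a Pyro model the result could be added to the joint density with pyro.factor.

```python
import torch

def hmm_log_likelihood(init_logits, trans_logits, emission_ll):
    """Hypothetical helper: forward algorithm in log space.

    init_logits:  [S]     log p(x_0 = s), normalized
    trans_logits: [S, S]  log p(x_t = j | x_{t-1} = i), rows normalized
    emission_ll:  [T, S]  log p(y_t | x_t = s), already summed over all
                          observation types (discrete and continuous)
    returns log p(y_0, ..., y_{T-1}) with the hidden states summed out
    """
    alpha = init_logits + emission_ll[0]  # [S]
    for t in range(1, emission_ll.shape[0]):
        # alpha_new[j] = logsumexp_i(alpha[i] + trans[i, j]) + emission[t, j]
        alpha = torch.logsumexp(alpha.unsqueeze(-1) + trans_logits, dim=0) + emission_ll[t]
    return torch.logsumexp(alpha, dim=0)
```

Because model_0 starts from state_x = 0 and draws the first state from row 0 of probs_x, the matching inputs here would be init_logits = probs_x[0].log() and trans_logits = probs_x.log().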

Update: I now have a fully functioning HMM with 2 continuous and 2 categorical observations. I will post a link to a GitHub repository (is that OK?) as soon as I have finished writing my thesis and cleaned up the code. If someone comes across this post and could use the code before it is cleaned up, message me and I’ll send it directly.

Thank you for your help @fritzo


Also, changing the learning rate somehow fixed the “found vars in model but not in guide” issue.

Sure, feel free to link to GitHub!