Data parallel HMM question

Hi,

I’m trying to build an HMM where each sequence has its own emission/transition parameters. Starting from a basic HMM:

import torch
import pyro
import pyro.distributions as dist


def model1(data, hidden_dim=2):
    with pyro.plate("hidden_state", hidden_dim):
        transition = pyro.sample("transition", dist.Dirichlet(0.5 * torch.ones(hidden_dim)))
        emission = pyro.sample("emission", dist.HalfNormal(5))

    x = 0
    for t, y in pyro.markov(enumerate(data)):
        x = pyro.sample("x_{}".format(t), dist.Categorical(transition[x]),
                        infer={"enumerate": "parallel"})

        pyro.sample("y_{}".format(t), dist.Poisson(emission[x]), obs=y)

where data.shape = (sequence_length,).
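For reference, the training setup I'm using is roughly the following (a minimal sketch; the toy data, learning rate, number of steps, and the poutine.block exposure in the guide are just illustrative assumptions on my part):

from pyro import poutine
from pyro.infer import SVI, TraceEnum_ELBO
from pyro.infer.autoguide import AutoDiagonalNormal
from pyro.optim import Adam

pyro.clear_param_store()

# Guide over the continuous sites only; the enumerated x_t's are handled by
# TraceEnum_ELBO. (Exposing only "transition" and "emission" is an assumption
# here, following the usual blocked-autoguide pattern.)
guide = AutoDiagonalNormal(
    poutine.block(model1, expose=["transition", "emission"]))

data = torch.tensor([0., 1., 3., 4., 1., 0., 2., 5.])  # toy Poisson counts
elbo = TraceEnum_ELBO(max_plate_nesting=1)
svi = SVI(model1, guide, Adam({"lr": 0.05}), elbo)
for step in range(500):
    svi.step(data)

print(guide.quantiles([0.05, 0.5, 0.95]))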

I want to add an outer plate over independent sequences:

def model2(data, hidden_dim=2):

    n_data = len(data)
    hidden_plate = pyro.plate("hidden_state", hidden_dim, dim=-1)
    data_plate = pyro.plate("data", n_data, dim=-2)

    with data_plate, hidden_plate:
        transition = pyro.sample("transition", dist.Dirichlet(0.5 * torch.ones(hidden_dim)).expand([n_data, hidden_dim]))
        emission = pyro.sample("emission", dist.HalfNormal(5).expand([n_data, hidden_dim]))

    with data_plate as i:
        x = torch.zeros((data.shape[0], 1), dtype=torch.long)
        for t in pyro.markov(range(data.shape[1])):
            p_x_new = torch.cat([transition[a, b, :].unsqueeze(0) for a, b in zip(i, x[:, 0])]).unsqueeze(1)
            x_new_dist = dist.Categorical(p_x_new)
            x = pyro.sample("x_{}".format(t), x_new_dist,
                            infer={"enumerate": "parallel"})

            p_emission = torch.cat([emission[a, b].unsqueeze(0) for a, b in zip(i, x[:, 0])]).unsqueeze(1)
            y = pyro.sample("y_{}".format(t), dist.Poisson(p_emission), obs=data[i, t].unsqueeze(1))

where now data.shape = (n_data, sequence_length).

model2 seems to run if I use an AutoDiagonalNormal guide, and I can use elbo.compute_marginals to get the discrete hidden states, which look good. The problem is that when I look at guide.quantiles the emission/transition parameters are not right: the locs never converge and the scales just cover the whole width of the prior. By contrast, with model1 the emission posteriors collapse to a narrow width around the correct values.

Comparing model1 with model2 on a single sequence should give identical results, and inspecting the traces of the model and guide for both I found that all of the log_probs are indeed identical for a fixed random seed. However, with TraceEnum_ELBO the ELBO is different. I also tried AutoDelta as the guide; in that case both models give the same ELBO.
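Roughly, the check I'm doing looks like this (a sketch; guide1/guide2 stand for the corresponding AutoDiagonalNormal guides and seq for the single sequence, so those names are placeholders):

import pyro
from pyro import poutine
from pyro.infer import TraceEnum_ELBO

# guide1/guide2 are the guides for model1/model2 (placeholder names),
# seq is a single sequence: seq.shape == (T,), seq[None].shape == (1, T).
def site_log_probs(model, guide, *args):
    pyro.set_rng_seed(0)
    guide_tr = poutine.trace(guide).get_trace(*args)
    model_tr = poutine.trace(poutine.replay(model, trace=guide_tr)).get_trace(*args)
    model_tr.compute_log_prob()
    return {name: site["log_prob_sum"].item()
            for name, site in model_tr.nodes.items() if site["type"] == "sample"}

print(site_log_probs(model1, guide1, seq))
print(site_log_probs(model2, guide2, seq[None]))

elbo = TraceEnum_ELBO(max_plate_nesting=2)
pyro.set_rng_seed(0)
print(elbo.loss(model1, guide1, seq))
pyro.set_rng_seed(0)
print(elbo.loss(model2, guide2, seq[None]))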

It looks like the issue has to do with the combination of AutoDiagonalNormal and TraceEnum_ELBO. If anyone can give any pointers on what's going on, that would be appreciated!

Hi @dnlbunting, my guess is that there is a silent shape bug due to the unsqueezing and reshaping and such. For example, I would guess that to support enumeration and broadcasting you’d have to change

- zip(i, x[:, 0])
+ zip(i, x[..., 0])

and

- unsqueeze(0)  # replace positive dim with negative dim, to allow broadcasting
+ unsqueeze(-2)  # or -1 or whatever negative number is appropriate

Also, I find it helps to insert shape assertions after each statement. If I were debugging your model, I'd start by moving the time dimension to the left so as to avoid unsqueezing. I feel like we could use a tutorial just on advanced indexing (and I would consult it regularly :smile:).
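Concretely, one enumeration-friendly way to write the indexing is with pyro.ops.indexing.Vindex rather than the torch.cat list comprehensions. Here is an untested sketch (the exact unsqueezing of the plate index to align with dim=-2 is a guess on my part):

import torch
import pyro
import pyro.distributions as dist
from pyro.ops.indexing import Vindex


def model_vindex(data, hidden_dim=2):
    # data.shape == (n_data, sequence_length)
    n_data = len(data)
    hidden_plate = pyro.plate("hidden_state", hidden_dim, dim=-1)
    data_plate = pyro.plate("data", n_data, dim=-2)

    with data_plate, hidden_plate:
        transition = pyro.sample(
            "transition",
            dist.Dirichlet(0.5 * torch.ones(hidden_dim)).expand([n_data, hidden_dim]))
        emission = pyro.sample(
            "emission", dist.HalfNormal(5).expand([n_data, hidden_dim]))

    with data_plate as i:
        idx = i.unsqueeze(-1)  # align the sequence index with plate dim=-2
        x = torch.zeros(n_data, 1, dtype=torch.long)  # initial hidden state
        for t in pyro.markov(range(data.shape[1])):
            # Transition row for (sequence idx, current state x); Vindex
            # broadcasts correctly when x picks up enumeration dims.
            probs_t = Vindex(transition)[idx, x, :]
            assert probs_t.shape[-1] == hidden_dim
            x = pyro.sample("x_{}".format(t), dist.Categorical(probs_t),
                            infer={"enumerate": "parallel"})

            # Poisson rate for (sequence idx, new state x).
            rate_t = Vindex(emission)[idx, x]
            pyro.sample("y_{}".format(t), dist.Poisson(rate_t),
                        obs=data[i, t].unsqueeze(-1))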

Hi @fritzo, thanks for replying. I've switched to ellipses and negative axes, but still no joy.

def model(data, hidden_dim=2):

    n_data = len(data)
    hidden_plate = pyro.plate("hidden_state", hidden_dim, dim=-1)
    data_plate = pyro.plate("data", n_data, dim=-2)
    with data_plate as i:

        with hidden_plate:
            transition = pyro.sample("transition", dist.Dirichlet(0.5 * torch.ones(hidden_dim)).expand([n_data, hidden_dim]))
            emission = pyro.sample("emission", dist.HalfNormal(5).expand([n_data, hidden_dim]))

        x = torch.zeros((data.shape[0], 1), dtype=torch.long)
        for t in pyro.markov(range(data.shape[1])):
            p_x_new = torch.cat([transition[a, b, :].unsqueeze(-2) for a, b in zip(i, x[..., 0])]).unsqueeze(-2)
            x = pyro.sample("x_{}".format(t), dist.Categorical(p_x_new),
                            infer={"enumerate": "parallel"})

            p_emission = torch.cat([emission[a, b].unsqueeze(-1) for a, b in zip(i, x[..., 0])]).unsqueeze(-1)
            y = pyro.sample("y_{}".format(t), dist.Poisson(p_emission), obs=data[i, t].unsqueeze(-1))

I've been looking at the shapes and I'm a bit confused; the trace looks like this:

    Trace Shapes:            
     Param Sites:            
    Sample Sites:            
hidden_state dist         |  
            value       2 |  
        data dist         |  
            value       3 |  
  transition dist     3 2 | 2
            value     3 2 | 2
    emission dist     3 2 |  
            value     3 2 |  
         x_0 dist     3 1 |  
            value   2 1 1 |  
         y_0 dist   2 3 1 |  
            value     3 1 |  
         x_1 dist   2 3 1 |  
            value 2 1 1 1 |  
         y_1 dist 2 1 3 1 |  
            value     3 1 |  

for data.shape = (3, 52).

Why do the x's have a shape of 1 on the -2 (data plate) axis? I would have expected that to be n_data = 3, since p_x_new.shape = (n_data, n_hidden).

Thanks!

The reason is that, when running exact inference using TraceEnum_ELBO, the xs are enumerated, and their set of enumerated values is independent of data; hence Pyro compresses the tensors (via .enumerate_support(expand=False)) and relies on broadcasting to ensure the .log_prob(x) tensors have full shape. By keeping tensors compressed/unbroadcasted, Pyro can do efficient message-passing exact inference over discrete random variables.
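Here's a tiny standalone illustration of the compressed vs. expanded support and how broadcasting recovers the full log_prob shape (not from your model):

import torch
import pyro.distributions as dist

d = dist.Categorical(torch.ones(3, 2))           # batch_shape (3,), 2 states
print(d.enumerate_support(expand=True).shape)    # torch.Size([2, 3]): one value per batch element
print(d.enumerate_support(expand=False).shape)   # torch.Size([2, 1]): compressed
print(d.log_prob(d.enumerate_support(expand=False)).shape)  # torch.Size([2, 3]): broadcasting restores full shape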