Problem with enumeration for batch data in deep Markov models

Hi,

I am trying to follow the neural HMM example to build a DBN for a predictive task (model description below).

[model diagram: pyroardbn]

I am working with mostly discrete data, and I am using two MLP nn.Module modules (StateEmitter and Emitter) to model the state transitions and the emissions of the state-space model.

The problem occurs after the first batch during inference.

Code

The MLP modules look like

class Emitter(nn.Module):
    ...
    def forward(self, y, z):
        # Check dimension of y so this can be used with and without enumeration.
        if y.dim() < 2:
            y = y.unsqueeze(0)

        # move to onehot representation
        z_onehot = self.int2onehot(z, self.num_states, y.dtype, y.device).type(torch.float)
        y_onehot = self.int2onehot(y, self.num_categories, y.dtype, y.device, add_batch_dim=True).type(torch.float)

        # compute the linear projection of the onehot y_{t-1}. The onehot state vector z will be enumerated
        # onehot vectors dim  [batch_size, channels, length]
        y_conv = self.relu(self.conv_y(y_onehot)).reshape(y.shape[:-1] + (-1,))

        # add computed layer, project to y's (output) dimension and turn into probabilities
        proposed_alpha = self.lin_hidden_to_y(self.lin_y_to_y_hidden(y_conv) + self.lin_z_to_z_hidden(z_onehot))
        alpha = self.softmax(proposed_alpha)
        return alpha

class GatedStateTransition(nn.Module):
    ...
    def forward(self, w, z):

        if w.dim() < 2:
            w = w.unsqueeze(0)

        # compute the gating function
        _gate = self.relu(self.lin_gate_w_to_hidden(w))
        gate = self.sigmoid(self.lin_gate_hidden_to_z(_gate))

        _proposed_alpha = self.relu(self.lin_proposed_concentration_w_to_hidden(w))
        proposed_alpha = self.lin_proposed_concentration_hidden_to_z(_proposed_alpha)

        z_long = torch.Tensor([z]).type(torch.LongTensor) if not torch.is_tensor(z) else z
        z_onehot = (
            torch.zeros(z_long.shape[:-1] + (self.num_states,), dtype=w.dtype, device=w.device).scatter_(-1, z_long, 1))
        alpha = self.softmax((1 - gate) * self.lin_z_to_concentration(z_onehot) + gate * proposed_alpha)
        return alpha

def model(self, sequences, include_prior=True):
        ...
        output_dim = output_seq[0].shape[1]
        pyro.module("state_emitter", self.state_emitter)
        pyro.module("ar_emitter", self.ar_emitter)

        with poutine.mask(mask=include_prior):
            probs_lat = pyro.sample("probs_lat",
                                    dist.Dirichlet(
                                        0.5 * torch.eye(self.num_states) + 0.5 / (self.num_states - 1)).to_event(1))

        obs_plate = pyro.plate("obs", output_dim, dim=-1)
        with pyro.plate("sequence_list", self.num_seqs, self.batch_size, dim=-2) as batch:
            lengths = self.lengths[batch]
            z = 0
            y = torch.zeros(self.args.batch_size,1)
            for t in pyro.markov(range(0, self.max_lenght if self.args.jit else self.lengths.max())):
                with poutine.mask(mask=(t < lengths).unsqueeze(-1)):
                    emitted_x = pyro.sample(f"emit_x_{t}", dist.Categorical(self.state_emitter(input_seq[batch, t], z)[:,None,:]),
                                            infer={"enumerate": "parallel"})
                    z = pyro.sample(f"z_{t}", dist.Categorical(probs_lat[emitted_x]),
                                    infer={"enumerate": "parallel"})
                    with obs_plate:
                            y = pyro.sample(f"y_{t}", dist.Categorical(self.ar_emitter(y, z)).to_event(1), obs=output_seq[batch, t])

I am using TraceEnum_ELBO and an AutoDelta(poutine.block(self._model, expose=["probs_lat"])) guide in SVI(model, guide, optim, elbo).

Questions

  • The first batch runs fine, but the second batch alters the size of emitted_x: a dimension is added after each batch. The code fails in the second batch (in the last line of the model, y = ...) with the error
    ValueError: Shape mismatch inside plate('sequence_list') at site y_0 dim -2, 30 vs 6. I can’t figure out why, but I suspect a misuse of enumeration.
  • Am I using enumeration correctly? My data is mostly discrete; however, I am confused about feeding MLP outputs into Categorical distributions (see the sketch after this list).
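
For context, here is a minimal standalone sketch of the pattern I mean, with purely illustrative shapes and modules (not my actual code): an MLP’s softmax output parameterizes a Categorical that is enumerated in parallel.

import torch
import torch.nn as nn
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer import config_enumerate

num_states, batch_size, feature_dim = 6, 30, 4
mlp = nn.Sequential(nn.Linear(feature_dim, num_states), nn.Softmax(dim=-1))

@config_enumerate
def toy_model(x):
    probs = mlp(x)  # shape (batch_size, num_states); each row sums to 1
    with pyro.plate("batch", batch_size, dim=-1):
        # Under parallel enumeration this site takes the enumerated support as its
        # value, with the enumeration dimension added on the left.
        return pyro.sample("state", dist.Categorical(probs))

x = torch.randn(batch_size, feature_dim)
trace = poutine.trace(poutine.enum(toy_model, first_available_dim=-2)).get_trace(x)
print(trace.nodes["state"]["value"].shape)  # torch.Size([6, 1])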

Hi, I recommend carefully reading our tutorials on enumeration and tensor shapes, especially the section on writing parallelizable code. You should be able to get your model working nicely if you apply that section’s advice about indexing tensors from the right and using the Vindex helper liberally, including in your various helper functions (e.g. int2onehot).
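
For illustration, here is a small standalone sketch of the kind of batched indexing Vindex handles; the per-sequence transition matrix is just an assumption to make the contrast with plain fancy indexing visible:

import torch
from pyro.ops.indexing import Vindex

batch_size, num_states = 30, 6
# Suppose each sequence in the batch had its own transition matrix...
probs = torch.rand(batch_size, num_states, num_states)
probs = probs / probs.sum(-1, keepdim=True)
# ...and we hold one current state per sequence.
z = torch.randint(num_states, (batch_size,))

# Plain fancy indexing pairs every batch element with every index value:
print(probs[..., z, :].shape)          # torch.Size([30, 30, 6]) -- not what we want

# Vindex broadcasts the index against the batch dims from the right:
print(Vindex(probs)[..., z, :].shape)  # torch.Size([30, 6])

The same Vindex(...)[..., index, :] pattern keeps working when the index picks up extra enumeration dimensions on the left, which is why it is worth using throughout your helper functions as well.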

You should also be able to get rid of the slicing you’re performing to compute emitted_x in your model:

...
px = self.state_emitter(input_seq[batch, t], z)
emitted_x = pyro.sample(f"emit_x_{t}", dist.Categorical(px),
                        infer={"enumerate": "parallel"})
...

Thanks.
I only got back to the project now. I’ve read the sources, and even though the enumeration topic is not that intuitive, I think I spotted part of the problem.
Anyway, I am trying to fix up my code, starting with the px computation:

...
with pyro.plate("sequence_list", size=self.num_seqs, subsample_size=self.batch_size, dim=-2) as batch:
   lengths = self.lengths[batch]
   z = 0
   y = torch.zeros(self.args.batch_size, 1)
   for t in pyro.markov(range(0, self.max_lenght if self.args.jit else self.lengths.max())):
      with poutine.mask(mask=(t < lengths).unsqueeze(-1)):
         
         # px.shape = [batch_size X num_states]
         px = self.state_emitter(input_seq[batch, t], z)

         # emitted_x.shape = [batch_size X 1]
         # FIXME: emitted_x.shape  = [batch_size X batch_size]
         emitted_x = pyro.sample(f"emit_x_{t}", dist.Categorical(px))
...

I saw that when the guide executes the code for the first time, I expect emitted_x.shape = (batch_size, 1), since px.shape = (batch_size, num_states).
That is what I get when I call dist.Categorical(px).sample() directly; however, pyro.sample(f"emit_x_{t}", dist.Categorical(px)) returns a (batch_size, batch_size) tensor. As I understand it, this happens because of the sequence_list plate, but I haven’t figured out yet how to work around it.
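
To convince myself of where the extra dimension comes from, here is a standalone sketch (illustrative numbers only, detached from my model) comparing a plate at dim=-1 with one at dim=-2 around a Categorical whose batch dimension already sits at dim -1:

import torch
import pyro
import pyro.distributions as dist

batch_size, num_states = 30, 6
px = torch.rand(batch_size, num_states)
px = px / px.sum(-1, keepdim=True)  # a valid probs tensor of shape (30, 6)

# With the plate at dim=-1, the Categorical's existing batch dim lines up with the plate:
with pyro.plate("batch_right", batch_size, dim=-1):
    x1 = pyro.sample("sample_right", dist.Categorical(px))

# With the plate at dim=-2, the existing batch dim at -1 is left alone and the plate
# adds its own size-30 dimension at dim -2:
with pyro.plate("batch_left", batch_size, dim=-2):
    x2 = pyro.sample("sample_left", dist.Categorical(px))

print(x1.shape)  # torch.Size([30])
print(x2.shape)  # torch.Size([30, 30])

So with dim=-2 the plate adds its own size-30 dimension instead of reusing the Categorical’s existing batch dimension, which matches the (batch_size, batch_size) shape I am seeing.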

Another update:

  • I removed the obs_plate (my output space is single-dimensional), so that the rightmost dimension is the batch dimension
  • I used Vindex for sampling z

The code:

def model(self, sequences, include_prior=True):
        ...
        pyro.module("state_emitter", self.state_emitter)
        pyro.module("ar_emitter", self.ar_emitter)

        with poutine.mask(mask=include_prior):
            # transition matrix in the hidden state [ num_states X num_states ]
            probs_lat = pyro.sample("probs_lat", dist.Dirichlet(
                0.5 * torch.eye(self.num_states) + 0.5 / (self.num_states - 1)).to_event(1))
        with pyro.plate("sequence_list", size=self.num_seqs, subsample_size=self.batch_size) as batch:
            lengths = self.lengths[batch]
            z = 0
            y = torch.zeros(self.args.batch_size, 1)
            input_batch = input_seq[batch, :]
            for t in pyro.markov(range(0, self.max_lenght if self.args.jit else self.lengths.max())):
                with poutine.mask(mask=(t < lengths).unsqueeze(-1)):
                                      
                    px = self.state_emitter(input_batch[:, t, :], z)  # px.shape = [batch_size X num_states]
                    emitted_x = pyro.sample(f"emit_x_{t}", dist.Categorical(px))  # emitted_x.shape = [batch_size X 1]

                    z = pyro.sample(f"z_{t}", dist.Categorical(Vindex(probs_lat)[..., emitted_x, :]))  # z.shape = [batch_size X 1]

                    py = self.ar_emitter(y, z)  # py.shape = [batch_size X num_emission]
                    y = pyro.sample(f"y_{t}", dist.Categorical(py), obs=output_seq[batch, t])

Both the guide run and the enumerated model run seem fine.
The trace shapes info is:

...
      Sample Sites:                       
     probs_lat dist                  | 6 6
              value                  | 6 6
 sequence_list dist                  |    
              value             30   |    
      emit_x_0 dist             30   |    
              value       6  1   1   |    
           z_0 dist       6  1  30   |    
              value     6 1  1   1   |    
           y_0 dist       6  1  30   |    
              value         30   1   |    
      emit_x_1 dist       6  1  30   |    
              value   6 1 1  1   1   |    
           z_1 dist   6 1 1  1  30   |    
              value 6 1 1 1  1   1   |    
           y_1 dist   6 1 1  1  30   |    
              value         30   1   |  
...

However, svi.step raises the following error:

ValueError: at site "emit_x_0", invalid log_prob shape
  Expected [30], actual [30, 30]
  Try one of the following fixes:
  - enclose the batched tensor in a with plate(...): context
  - .to_event(...) the distribution being sampled
  - .permute() data dimensions
# the svi step
        self.elbo = Elbo(max_plate_nesting=2)
        optim = Adam({'lr': self.args.learning_rate})
        svi = SVI(self.model, self.guide, optim, self.elbo)

        # We'll train on small minibatches.
        self.logger.info('Step\tLoss')
        for step in range(self.args.num_steps):
            loss = svi.step(self.sequences)

I don’t really understand what’s wrong, given that the enumeration dimensions make sense.
I can also add the code for ar_emitter and state_emitter, but I doubt the problem is there.
Is there something wrong with the plates or the to_event usage?

It sounds like you’re getting the same error as before: emit_x_0 has log_prob shape (30, 30). Debugging shape errors can be tricky, and that doesn’t match the shapes you posted, so it’s hard to say what’s actually going on. In the previous version of your model you set dim=-2 in your plate, which could have explained that error, but it seems like you removed that in the latest version and you’re still seeing the same error?

Can you provide a minimal runnable example that reproduces your error, say without the guide, masking, or subsampling and with the neural network components removed or replaced by single linear layers?

Thanks,
I’ve played with the batch dimension, since that let me get further with the debugging. Here I try batch dim=-1 and get the error ValueError: at site "emit_x_0", invalid log_prob shape.
However, all the HMM examples use dim=-2, I guess so that they aren’t constrained to a 1D output dimension. Accordingly, I also tried a dim=-2 version, with no success.

Bottom line: using dim=-1 and removing the obs_plate allowed me to get further in debugging, to svi.step (with the aforementioned format_shapes), but it failed on the same “double-batch” shape assertion.

I will provide a minimal runnable example today or tomorrow.

@eb8680_2 here is a runnable example of the code. I really tried to keep it close to the original problem.
Here I am trying to work with batch dim=-1, and I figured out that the problem seems to be the masking, specifically the .unsqueeze(-1) in

with poutine.mask(mask=(t < torch.ones(batch_size) * seq_length).unsqueeze(-1))

Do you think I should drop the .unsqueeze(-1)? How should I move to a model with batch dim=-2?

The full example:

from unittest import TestCase
import os
import torch
import pyro
import pyro.distributions as dist
from torch.distributions import constraints
from pyro import poutine
from pyro.infer import SVI, Trace_ELBO, TraceEnum_ELBO, config_enumerate, infer_discrete
from pyro.infer.autoguide import AutoDiagonalNormal, AutoDelta
from pyro.ops.indexing import Vindex

pyro.enable_validation()
pyro.set_rng_seed(0)
num_seq = 100
seq_length = 1000

import torch
from torch import nn as nn


class TestEmitter(nn.Module):
    def __init__(self, w_dim=1, z_dim=1):
        super(TestEmitter, self).__init__()
        self.lin_w_to_hidden = nn.Linear(w_dim, z_dim)
        self.softmax = nn.Softmax(dim=0)

    def forward(self, w, z):
        w = w.unsqueeze(-1)
        w_hidden = self.lin_w_to_hidden(w)
        out = w_hidden
        out_sm = self.softmax(out)  # alpha is size of z
        return out_sm


@config_enumerate
def model(inp):
    num_states = 6
    batch_size = 30
    x = inp['x']
    y = inp['y']

    state_emitter = TestEmitter(w_dim=1, z_dim=6)
    obs_emitter = TestEmitter(w_dim=1, z_dim=25)
    pyro.module("state_emitter", state_emitter)
    pyro.module("obs_emitter", obs_emitter)

    with poutine.mask(mask=True):
        probs_lat = pyro.sample("probs_lat",
                                dist.Dirichlet(0.5 * torch.eye(num_states) + 0.5 / (num_states - 1)).to_event(1))

    z = torch.Tensor([0]).type(torch.FloatTensor)
    y_hat = torch.Tensor([0]).type(torch.FloatTensor)
    with pyro.plate("sequence", size=num_seq, subsample_size=batch_size) as batch:
        for t in pyro.markov(range(0, seq_length - 1)):
            # with poutine.mask(mask=(t < torch.ones(batch_size) * seq_length)):
            with poutine.mask(mask=(t < torch.ones(batch_size) * seq_length).unsqueeze(-1)):
                px = state_emitter(x[t, batch].type(torch.FloatTensor), z)
                z = pyro.sample(f"z_{t}", dist.Categorical(Vindex(probs_lat)[..., px.argmax(dim=1), :]))
                py = obs_emitter(z.type(torch.FloatTensor), y_hat)
                y_hat = pyro.sample(f"y_{t}", dist.Categorical(py), obs=y[t, batch])


def generate_data():
    x = torch.randint(0, 30, size=(seq_length, num_seq))
    y = torch.randint(0, 25, size=(seq_length, num_seq))
    seq = dict(x=x, y=y)
    return seq


def print_shapes(model, guide, seq, first_available_dim=-3):
    guide_trace = poutine.trace(guide).get_trace(seq)
    model_trace = poutine.trace(
        poutine.replay(poutine.enum(model, first_available_dim), guide_trace)).get_trace(seq)
    print(model_trace.format_shapes())


class TestPyroARDBN(TestCase):

    def test_architecture(self):
        seq = generate_data()
        guide = AutoDelta(poutine.block(model, expose=["probs_lat"]))
        pyro.clear_param_store()
        elbo = TraceEnum_ELBO(max_plate_nesting=3)
        print_shapes(model, guide, seq, first_available_dim=-3)
        elbo.loss(model, guide, seq)

Do you think I should drop the .unsqueeze(-1)?

Yes, the shape of your masks seems to be the problem: you’re applying a mask of shape (batch_size, 1) to a distribution of batch_shape (batch_size,) (or (num_states, 1, 1, batch_size)), causing the shape of site z_0’s log_prob tensor to broadcast to (num_states, 1, batch_size, batch_size). Removing that unsqueeze should fix things.
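
Here is a rough standalone illustration of that broadcasting, using the distribution-level .mask rather than poutine.mask (the shape arithmetic is the same):

import torch
import pyro.distributions as dist

batch_size, num_states = 30, 6
probs = torch.rand(batch_size, num_states)
probs = probs / probs.sum(-1, keepdim=True)
d = dist.Categorical(probs)   # batch_shape (30,)
value = d.sample()            # shape (30,)

good_mask = torch.ones(batch_size, dtype=torch.bool)  # shape (30,)
bad_mask = good_mask.unsqueeze(-1)                     # shape (30, 1)

print(d.mask(good_mask).log_prob(value).shape)  # torch.Size([30])
print(d.mask(bad_mask).log_prob(value).shape)   # torch.Size([30, 30])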

As a general tip, it’s easier to track down and squash these sorts of bugs if you assign plate dimensions manually and sprinkle your code liberally with shape assertions. I do this whenever I’m developing a complicated model, because Pyro can’t generically validate mask shapes. For example, an assertion that would have caught this bug in the test model is:

seq_plate = pyro.plate("sequence", ..., dim=-1)
with seq_plate as batch:
    ...
    mask = ...
    with poutine.mask(mask=mask):
        px = ...
        z_dist = dist.Categorical(Vindex(probs_lat)[..., px.argmax(dim=1), :])
        assert mask.shape == z_dist.batch_shape[seq_plate.dim:]
        z = pyro.sample(f"z_{t}", z_dist)

We’re hoping that named tensor dimensions in PyTorch 1.3 and Funsor will eventually allow us to dramatically reduce the occurrence of errors like this, which are ultimately caused by the difficulty of keeping track of tensor shapes in your head while developing.


FYI @noam I also opened a GitHub issue about validating mask shapes.

Thanks for all the help, you really saved me there.
