Problem with enumeration for batch data in deep Markov models

Hi,

I’m trying to follow the neural HMM example to build a DBN for a predictive task (model description below).

[figure: pyroardbn (model diagram)]

I am working with mostly discrete data, using two MLP nn.Module modules (StateEmitter and Emitter) to model the state transitions and the emissions of the state-space model.

The problem occurs after the first batch during inference.

Code

The MLP modules look like this:

class Emitter(nn.Module):
    ...
    def forward(self, y, z):
        # Check dimension of y so this can be used with and without enumeration.
        if y.dim() < 2:
            y = y.unsqueeze(0)

        # move to onehot representation
        z_onehot = self.int2onehot(z, self.num_states, y.dtype, y.device).type(torch.float)
        y_onehot = self.int2onehot(y, self.num_categories, y.dtype, y.device, add_batch_dim=True).type(torch.float)

        # compute the linear projection of the onehot y_{t-1}. The onehot state vector z will be enumerated
        # onehot vectors dim  [batch_size, channels, length]
        y_conv = self.relu(self.conv_y(y_onehot)).reshape(y.shape[:-1] + (-1,))

        # add computed layer, project to y's (output) dimension and turn into probabilities
        proposed_alpha = self.lin_hidden_to_y(self.lin_y_to_y_hidden(y_conv) + self.lin_z_to_z_hidden(z_onehot))
        alpha = self.softmax(proposed_alpha)
        return alpha

class GatedStateTransition(nn.Module):
    ...
    def forward(self, w, z):

        if w.dim() < 2:
            w = w.unsqueeze(0)

        # compute the gating function
        _gate = self.relu(self.lin_gate_w_to_hidden(w))
        gate = self.sigmoid(self.lin_gate_hidden_to_z(_gate))

        _proposed_alpha = self.relu(self.lin_proposed_concentration_w_to_hidden(w))
        proposed_alpha = self.lin_proposed_concentration_hidden_to_z(_proposed_alpha)

        # make sure z is a long tensor before building the onehot representation
        z_long = z if torch.is_tensor(z) else torch.tensor([z], dtype=torch.long)
        z_onehot = torch.zeros(z_long.shape[:-1] + (self.num_states,),
                               dtype=w.dtype, device=w.device).scatter_(-1, z_long, 1)
        alpha = self.softmax((1 - gate) * self.lin_z_to_concentration(z_onehot) + gate * proposed_alpha)
        return alpha
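
For reference, the int2onehot helper used above is omitted; here is a minimal sketch of such a helper (my simplified version; the actual implementation may differ in details):

import torch

def int2onehot(x, num_classes, dtype, device, add_batch_dim=False):
    # Scatter 1s into a new trailing num_classes dimension, mirroring the
    # inline scatter_ in GatedStateTransition.forward above.
    x = x if torch.is_tensor(x) else torch.tensor([x])
    onehot = torch.zeros(x.shape + (num_classes,), dtype=dtype, device=device)
    onehot.scatter_(-1, x.long().unsqueeze(-1), 1)
    if add_batch_dim and onehot.dim() < 3:
        onehot = onehot.unsqueeze(0)
    return onehot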

def model(self, sequences, include_prior=True):
        ...
        output_dim = output_seq[0].shape[1]
        pyro.module("state_emitter", self.state_emitter)
        pyro.module("ar_emitter", self.ar_emitter)

        with poutine.mask(mask=include_prior):
            probs_lat = pyro.sample("probs_lat",
                                    dist.Dirichlet(
                                        0.5 * torch.eye(self.num_states) + 0.5 / (self.num_states - 1)).to_event(1))

        obs_plate = pyro.plate("obs", output_dim, dim=-1)
        with pyro.plate("sequence_list", self.num_seqs, self.batch_size, dim=-2) as batch:
            lengths = self.lengths[batch]
            z = 0
            y = torch.zeros(self.args.batch_size, 1)
            for t in pyro.markov(range(0, self.max_length if self.args.jit else self.lengths.max())):
                with poutine.mask(mask=(t < lengths).unsqueeze(-1)):
                    emitted_x = pyro.sample(f"emit_x_{t}", dist.Categorical(self.state_emitter(input_seq[batch, t], z)[:,None,:]),
                                            infer={"enumerate": "parallel"})
                    z = pyro.sample(f"z_{t}", dist.Categorical(probs_lat[emitted_x]),
                                    infer={"enumerate": "parallel"})
                    with obs_plate:
                        y = pyro.sample(f"y_{t}", dist.Categorical(self.ar_emitter(y, z)).to_event(1),
                                        obs=output_seq[batch, t])

I am using TraceEnum_ELBO with an AutoDelta(poutine.block(self._model, expose=["probs_lat"])) guide in SVI(model, guide, optim, elbo).
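
Concretely, the inference setup looks roughly like this (the optimizer settings are placeholders):

from pyro import optim, poutine
from pyro.infer import SVI, TraceEnum_ELBO
from pyro.infer.autoguide import AutoDelta

guide = AutoDelta(poutine.block(self._model, expose=["probs_lat"]))
# max_plate_nesting=2: the model has plates at dim=-2 ("sequence_list") and dim=-1 ("obs")
elbo = TraceEnum_ELBO(max_plate_nesting=2)
svi = SVI(self._model, guide, optim.Adam({"lr": 1e-3}), elbo)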

Questions

  • The first batch goes fine, but the size of emitted_x changes: a dimension is added after each batch. The code fails on the second batch (at the last line of the model, y = ...) with the error
    ValueError: Shape mismatch inside plate('sequence_list') at site y_0 dim -2, 30 vs 6. I can’t figure out why, but I suspect a misuse of enumeration.
  • Am I using enumeration correctly? My data are mostly discrete; however, I am confused about feeding the outputs of MLP modules into Categorical distributions (minimal example below).
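
To clarify the second question, here is a minimal, standalone example of what I mean by an MLP output parameterizing a Categorical (shapes as I understand them):

import torch
import pyro.distributions as dist

# An MLP's softmax output: one probability vector per batch element.
probs = torch.softmax(torch.randn(6, 4), dim=-1)  # [batch=6, num_categories=4]
d = dist.Categorical(probs)  # batch_shape=[6], event_shape=[]
sample = d.sample()          # shape [6], integer category indices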

Hi, I recommend carefully reading our tutorials on enumeration and tensor shapes, especially the section on writing parallelizable code. You should be able to get your model working nicely if you apply the advice in that section about indexing tensors from the right and using the Vindex helper liberally throughout your code, including in your various helper functions (e.g. int2onehot).
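
For example, at the z sites, plain indexing like probs_lat[emitted_x] can break once emitted_x acquires enumeration dims on the left; a sketch of the Vindex version (assuming probs_lat has shape (num_states, num_states)):

from pyro.ops.indexing import Vindex

# Vindex keeps advanced indexing aligned from the right, no matter how many
# batch/enumeration dims the index tensor carries on the left.
probs_t = Vindex(probs_lat)[..., emitted_x, :]
z = pyro.sample(f"z_{t}", dist.Categorical(probs_t),
                infer={"enumerate": "parallel"})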

You should also be able to get rid of the slicing you’re performing to compute emitted_x in your model:

...
px = self.state_emitter(input_seq[batch, t], z)
emitted_x = pyro.sample(f"emit_x_{t}", dist.Categorical(px),
                        infer={"enumerate": "parallel"})
...