Hi,
I’m trying to follow the neural HMM example to build a DBN for a predictive task (model description below).
I am working with mostly discrete data, and I use two MLP `nn.Module` modules (`GatedStateTransition` and `Emitter`, registered below as `state_emitter` and `ar_emitter`) to model the state transitions and the emissions of the state space model.
The problem occurs after the first batch during inference.
Code
The MLP modules look like
```python
class Emitter(nn.Module):
    ...
    def forward(self, y, z):
        # Check dimension of y so this can be used with and without enumeration.
        if y.dim() < 2:
            y = y.unsqueeze(0)
        # move to onehot representation
        z_onehot = self.int2onehot(z, self.num_states, y.dtype, y.device).type(torch.float)
        y_onehot = self.int2onehot(y, self.num_categories, y.dtype, y.device, add_batch_dim=True).type(torch.float)
        # compute the linear projection of the onehot y_{t-1}. The onehot state vector z will be enumerated
        # onehot vectors dim [batch_size, channels, length]
        y_conv = self.relu(self.conv_y(y_onehot)).reshape(y.shape[:-1] + (-1,))
        # add computed layer, project to y's (output) dimension and turn into probabilities
        proposed_alpha = self.lin_hidden_to_y(self.lin_y_to_y_hidden(y_conv) + self.lin_z_to_z_hidden(z_onehot))
        alpha = self.softmax(proposed_alpha)
        return alpha
```
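(`int2onehot` is elided above; it is essentially a scatter_-based one-hot encoder. A simplified sketch of what I mean, not the exact implementation, which also arranges the `[batch_size, channels, length]` layout expected by `conv_y`:)

```python
import torch

def int2onehot(idx, num_classes, dtype, device, add_batch_dim=False):
    # Simplified sketch only: integer indices -> one-hot vectors along the last dim.
    idx_long = torch.as_tensor(idx, device=device).long()
    if add_batch_dim and idx_long.dim() < 2:
        idx_long = idx_long.unsqueeze(0)
    onehot = torch.zeros(idx_long.shape + (num_classes,), dtype=dtype, device=device)
    return onehot.scatter_(-1, idx_long.unsqueeze(-1), 1.0)
```

The state transition module looks like this: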
```python
class GatedStateTransition(nn.Module):
    ...
    def forward(self, w, z):
        if w.dim() < 2:
            w = w.unsqueeze(0)
        # compute the gating function
        _gate = self.relu(self.lin_gate_w_to_hidden(w))
        gate = self.sigmoid(self.lin_gate_hidden_to_z(_gate))
        # propose new concentrations from the input w
        _proposed_alpha = self.relu(self.lin_proposed_concentration_w_to_hidden(w))
        proposed_alpha = self.lin_proposed_concentration_hidden_to_z(_proposed_alpha)
        # one-hot encode the (possibly enumerated) state z
        z_long = torch.Tensor([z]).type(torch.LongTensor) if not torch.is_tensor(z) else z
        z_onehot = (
            torch.zeros(z_long.shape[:-1] + (self.num_states,), dtype=w.dtype, device=w.device)
            .scatter_(-1, z_long, 1))
        # gate between the projection of the previous state and the proposed concentrations
        alpha = self.softmax((1 - gate) * self.lin_z_to_concentration(z_onehot) + gate * proposed_alpha)
        return alpha
```
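The constructors are omitted above; for context, the layers referenced in `forward()` are plain linear maps, shaped roughly like this (a skeleton only, the hidden size is a placeholder and not my exact configuration):

```python
import torch.nn as nn

class GatedStateTransitionSkeleton(nn.Module):
    # Illustrative skeleton: layer names match the forward() above,
    # hidden_dim is a placeholder.
    def __init__(self, w_dim, num_states, hidden_dim=32):
        super().__init__()
        self.num_states = num_states
        self.lin_gate_w_to_hidden = nn.Linear(w_dim, hidden_dim)
        self.lin_gate_hidden_to_z = nn.Linear(hidden_dim, num_states)
        self.lin_proposed_concentration_w_to_hidden = nn.Linear(w_dim, hidden_dim)
        self.lin_proposed_concentration_hidden_to_z = nn.Linear(hidden_dim, num_states)
        self.lin_z_to_concentration = nn.Linear(num_states, num_states)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.softmax = nn.Softmax(dim=-1)
```

The model that ties the two modules together is: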
```python
def model(self, sequences, include_prior=True):
    ...
    output_dim = output_seq[0].shape[1]
    pyro.module("state_emitter", self.state_emitter)
    pyro.module("ar_emitter", self.ar_emitter)
    with poutine.mask(mask=include_prior):
        # Dirichlet prior over the rows of the latent probability matrix
        probs_lat = pyro.sample("probs_lat",
                                dist.Dirichlet(
                                    0.5 * torch.eye(self.num_states) + 0.5 / (self.num_states - 1)).to_event(1))
    # plate over the output dimensions (dim -1)
    obs_plate = pyro.plate("obs", output_dim, dim=-1)
    # subsampled plate over sequences (dim -2)
    with pyro.plate("sequence_list", self.num_seqs, self.batch_size, dim=-2) as batch:
        lengths = self.lengths[batch]
        z = 0
        y = torch.zeros(self.args.batch_size, 1)
        for t in pyro.markov(range(0, self.max_lenght if self.args.jit else self.lengths.max())):
            with poutine.mask(mask=(t < lengths).unsqueeze(-1)):
                emitted_x = pyro.sample(f"emit_x_{t}",
                                        dist.Categorical(self.state_emitter(input_seq[batch, t], z)[:, None, :]),
                                        infer={"enumerate": "parallel"})
                z = pyro.sample(f"z_{t}", dist.Categorical(probs_lat[emitted_x]),
                                infer={"enumerate": "parallel"})
                with obs_plate:
                    y = pyro.sample(f"y_{t}",
                                    dist.Categorical(self.ar_emitter(y, z)).to_event(1),
                                    obs=output_seq[batch, t])
```
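If I understand parallel enumeration correctly, enumerated sites are expanded over all their categories along a new dimension placed to the left of the plate dims, and that expanded tensor is what flows back into the MLPs as `z`. A toy illustration of the shape I mean (not my real probabilities):

```python
import torch
import pyro.distributions as dist

# A Categorical whose batch dims mimic the two plates above: [batch=6, obs=1], 3 categories.
probs = torch.full((6, 1, 3), 1.0 / 3)
d = dist.Categorical(probs)
print(d.batch_shape)                            # torch.Size([6, 1])
print(d.enumerate_support(expand=False).shape)  # torch.Size([3, 1, 1])
# Under TraceEnum_ELBO, it is this kind of expanded value (with the enumeration
# dimension moved left of all plate dims) that gets passed into the MLPs as z.
```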
I am using `TraceEnum_ELBO` with an `AutoDelta(poutine.block(self._model, expose=["probs_lat"]))` guide in `SVI(model, guide, optim, elbo)`.
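Roughly, the training setup looks like this (a sketch: the optimizer, learning rate, and `max_plate_nesting` shown here are placeholders, not my exact settings, and `sequences` stands in for my data):

```python
from pyro import poutine
from pyro.infer import SVI, TraceEnum_ELBO
from pyro.infer.autoguide import AutoDelta
from pyro.optim import Adam

guide = AutoDelta(poutine.block(self._model, expose=["probs_lat"]))
optim = Adam({"lr": 1e-3})
elbo = TraceEnum_ELBO(max_plate_nesting=2)  # two plates: sequence_list (dim -2) and obs (dim -1)
svi = SVI(self._model, guide, optim, elbo)
loss = svi.step(sequences)
```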
Questions
- The first batch goes fine, but the second batch alters the size of `emitted_x`: a dimension is added after each batch. The code fails in the second batch (on the last line of the model, `y = ...`) with the error ValueError: Shape mismatch inside plate('sequence_list') at site y_0 dim -2, 30 vs 6. I can’t figure out why, but I suspect a misuse of enumeration (a sketch of how I have been inspecting the shapes is below).
- Am I using enumeration correctly? My data are mostly discrete; however, I am confused about feeding the outputs of the MLP modules into Categorical distributions.
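This is roughly how I have been inspecting the shapes (a sketch: `sequences` stands in for one mini-batch, and first_available_dim=-3 leaves room for the two plates):

```python
from pyro import poutine

# Trace the guide, replay the model against it under enumeration, and print the shapes.
guide_trace = poutine.trace(guide).get_trace(sequences)
model_trace = poutine.trace(
    poutine.replay(poutine.enum(self._model, first_available_dim=-3), trace=guide_trace)
).get_trace(sequences)
print(model_trace.format_shapes())
```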