Another update:
- I removed the obs_plate (since my output space is single-dimensional), so that the rightmost dimension is the batch dimension.
- I used Vindex for sampling z.
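For context, here is a minimal standalone check of how I expect Vindex to pick transition rows (dummy shapes only; num_states = 6 and a batch of 4 are placeholders):

import torch
from pyro.ops.indexing import Vindex

probs_lat = torch.softmax(torch.randn(6, 6), dim=-1)  # [num_states, num_states]
emitted_x = torch.randint(0, 6, (4,))                 # [batch_size]
rows = Vindex(probs_lat)[..., emitted_x, :]           # one transition row per batch element
print(rows.shape)  # torch.Size([4, 6])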
The code:
def model(self, sequences, include_prior=True):
    ...
    pyro.module("state_emitter", self.state_emitter)
    pyro.module("ar_emitter", self.ar_emitter)
    with poutine.mask(mask=include_prior):
        # transition matrix over the hidden state [num_states x num_states]
        probs_lat = pyro.sample("probs_lat", dist.Dirichlet(
            0.5 * torch.eye(self.num_states) + 0.5 / (self.num_states - 1)).to_event(1))
    with pyro.plate("sequence_list", size=self.num_seqs, subsample_size=self.batch_size) as batch:
        lengths = self.lengths[batch]
        z = 0
        y = torch.zeros(self.args.batch_size, 1)
        input_batch = input_seq[batch, :]
        for t in pyro.markov(range(0, self.max_length if self.args.jit else self.lengths.max())):
            with poutine.mask(mask=(t < lengths).unsqueeze(-1)):
                px = self.state_emitter(input_batch[:, t, :], z)  # px.shape = [batch_size x num_states]
                emitted_x = pyro.sample(f"emit_x_{t}", dist.Categorical(px))  # emitted_x.shape = [batch_size x 1]
                z = pyro.sample(f"z_{t}", dist.Categorical(Vindex(probs_lat)[..., emitted_x, :]))  # z.shape = [batch_size x 1]
                py = self.ar_emitter(y, z)  # py.shape = [batch_size x num_emission]
                y = pyro.sample(f"y_{t}", dist.Categorical(py), obs=output_seq[batch, t])
Both the guide run and the enumeration pass seem fine.
The trace shapes info is:
...
Sample Sites:
probs_lat dist | 6 6
value | 6 6
sequence_list dist |
value 30 |
emit_x_0 dist 30 |
value 6 1 1 |
z_0 dist 6 1 30 |
value 6 1 1 1 |
y_0 dist 6 1 30 |
value 30 1 |
emit_x_1 dist 6 1 30 |
value 6 1 1 1 1 |
z_1 dist 6 1 1 1 30 |
value 6 1 1 1 1 1 |
y_1 dist 6 1 1 1 30 |
value 30 1 |
...
However, svi.step raises the following error:
ValueError: at site "emit_x_0", invalid log_prob shape
Expected [30], actual [30, 30]
Try one of the following fixes:
- enclose the batched tensor in a with plate(...): context
- .to_event(...) the distribution being sampled
- .permute() data dimensions
# the svi step
self.elbo = Elbo(max_plate_nesting=2)
optim = Adam({'lr': self.args.learning_rate})
svi = SVI(self.model, self.guide, optim, self.elbo)
# We'll train on small minibatches.
self.logger.info('Step\tLoss')
for step in range(self.args.num_steps):
loss = svi.step(self.sequences)
I don’t really understand what’s wrong, given that the enumeration dimensions make sense.
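The only candidate I can see is the mask: (t < lengths).unsqueeze(-1) has shape [30, 1] (possibly a leftover from when I still had the obs_plate), and a [30, 1] mask broadcast against a [30] log_prob would give exactly [30, 30]. A standalone sketch of the broadcast I mean (plain dummy numbers, 30 standing in for my batch size):

import torch
import pyro.distributions as dist

log_prob = dist.Categorical(torch.ones(30, 6)).log_prob(torch.zeros(30).long())
print(log_prob.shape)  # torch.Size([30])

mask = torch.ones(30, dtype=torch.bool).unsqueeze(-1)  # [30, 1]
print((log_prob * mask).shape)  # torch.Size([30, 30]) -- the shape in the error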
I can also post the code for ar_emitter and state_emitter; however, I doubt the problem is there.
Is there something wrong with my plates or to_event usage?