I am trying to adapt the DMM code to my requirements. Specifically, I need the following:
- The number of time steps differs for each sample (sequence).
- At each time step, z_{t} generates not a single vector but a sequence of vectors. The number of vectors that the latent code z_{t} generates is not fixed.
I came up with the following generative model, which matches these requirements. Could anyone comment on inference in such a model? Would the guide for the DMM work in this case as well? If not, can I write my model in a different way that makes inference possible?
import torch
import torch.nn as nn
from torch.autograd import Variable
import pyro
import pyro.distributions as dist

# GatedTransition and Emitter_RNN are defined elsewhere (not shown here).

def next_state(i, t, z_prev, transition_function):
    # Sample z_t given z_{t-1}; the site name includes i so names stay unique across sequences.
    z_mu, z_sigma = transition_function(z_prev)
    z_t = pyro.sample('z_{}_{}'.format(i, t), dist.normal, z_mu, z_sigma)
    return z_t

def generate_sequence(z_t, emit, sequence, pointer, bias_coin):
    # Emit a variable number of vectors from z_t; stop on a 0 flip or at the end of the sequence.
    hidden = z_t
    data = emit.initInput()
    ps = []
    flip = Variable(torch.Tensor([1]))
    while flip.data[0] == 1 and pointer < len(sequence):
        hidden, out = emit(data, hidden)
        ps.append(out)  # keep the Variable (not out.data) so gradients reach the emitter
        data = out
        # pyro.sample('flip_{}'.format(pointer), dist.bernoulli, bias_coin)?
        flip = dist.bernoulli(bias_coin)
        pointer = pointer + 1
    ps = torch.stack(ps, dim=1)
    return ps, pointer

def model(sequences):
    z_dim = 100
    transition_dim = 100
    data_dim = 39
    emission_dim = 100
    trans = GatedTransition(z_dim, transition_dim)
    emit = Emitter_RNN(data_dim, z_dim, emission_dim)
    z_0 = nn.Parameter(torch.zeros(z_dim))
    z_prev = z_0
    bias_coin = Variable(torch.Tensor([0.5]))
    for i in range(len(sequences)):
        sequence = sequences[i]
        pointer = 0
        t = 0
        # Keep generating latent states until the whole sequence is covered.
        while pointer < len(sequence):
            t = t + 1
            z_t = next_state(i, t, z_prev, trans)
            ps, pointer = generate_sequence(z_t, emit, sequence, pointer, bias_coin)
            pyro.sample('obs_{}_{}'.format(i, t), dist.bernoulli, ps)
            z_prev = z_t
PS: the flip statement in generate_sequence makes the while loop stochastic. I keep generating latent states until the end of the sequence is reached.
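To make the stopping rule concrete, here is a minimal plain-Python sketch (no Pyro involved) of how the coin flips split a sequence into variable-length segments, one segment per latent state. The function name segment_lengths and the use of random.Random are illustrative choices for this sketch only, not part of the model code above.

```python
import random

def segment_lengths(seq_len, bias_coin=0.5, rng=None):
    """Return how many observations each latent state z_t emits."""
    rng = rng or random.Random()
    lengths = []
    pointer = 0
    while pointer < seq_len:              # outer loop: one iteration per z_t
        count = 0
        flip = 1                          # every z_t emits at least one vector
        while flip == 1 and pointer < seq_len:
            count += 1
            pointer += 1
            # Continue with probability bias_coin, otherwise move to the next z_t.
            flip = 1 if rng.random() < bias_coin else 0
        lengths.append(count)
    return lengths

lengths = segment_lengths(10, bias_coin=0.5, rng=random.Random(0))
print(lengths, sum(lengths))
```

Whatever the flips turn out to be, the segment lengths sum to the sequence length, so the number of latent states per sequence is itself random and lies between 1 and len(sequence).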
Thanks for any insights regarding inference in this model.