Sampling from a trained HMM

Hi, I have an HMM with Gaussian latent states. The sequences are not all the same length, so I use poutine.mask. Furthermore, each time point in each sequence has a covariate, which enters my transition probabilities. The model:

import torch
import pyro
import pyro.distributions as dist


def model(lengths, delays, sequences=None, batch_size=None):
    # Global parameters shared across all sequences.
    latent_std = pyro.sample('latent_std', dist.Uniform(0.0, 1.0))
    init = pyro.sample('init', dist.Normal(0.0, 1.5))
    increment = pyro.sample('increment', dist.Normal(0.0, 1.0))
    b_delay = pyro.sample('delay', dist.Normal(0.0, 1.5))

    with pyro.plate('sequences', len(lengths), batch_size) as batch:
        lengths = lengths[batch]
        x = init

        for t in pyro.markov(range(lengths.max())):
            # Mask out time steps beyond each sequence's length.
            with pyro.poutine.mask(mask=t < lengths):
                x_mean = x + increment + b_delay * delays[batch, t]
                x = pyro.sample(f'x_{t}', dist.Normal(x_mean, latent_std))
                obs = None if sequences is None else sequences[batch, t]
                y = pyro.sample(f'y_{t}', dist.Bernoulli(logits=x), obs=obs)
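
For reference, a sketch of the training setup (the AutoNormal autoguide and Adam settings here are stand-ins, not necessarily what I used; train_lengths, train_delays, train_sequences are hypothetical names for my padded training tensors):

from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoNormal
from pyro.optim import Adam

# Assumed setup: an autoguide fit with SVI on the padded training tensors.
guide = AutoNormal(model)
svi = SVI(model, guide, Adam({'lr': 0.01}), Trace_ELBO())
for step in range(2000):
    svi.step(train_lengths, train_delays, train_sequences)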

After training on my data, the inferred parameters make sense. I want to do three things:

  1. Sample new sequences
    I want to sample from the posterior predictive. Since my model has covariates, I want to do this sampling with respect to some new covariates. I thought this would work:
# One new sequence of length 54, with a nonzero delay covariate at t = 30.
lengths = torch.IntTensor([54])
delays = torch.zeros((1, 54))
delays[:, 30] = 5
post_samples = pyro.infer.Predictive(model, guide=guide, num_samples=101)(lengths, delays, None)

However, this leads to a vague error that ends with:

          value   1 |
 sequences dist     |
          value   1 |
       x_0 dist   1 |
          value 200 |

The model was trained on 200 sequences, so I deduce that the guide also expects 200 sequences as input to Predictive. Is there any way to circumvent this problem?

  2. Sample the next states of existing sequences
    For the sequences in my training data, I want to sample what the trained model thinks the remainder of each sequence might look like. Using Predictive on the training data works (sketched below), but for the part of each sequence beyond its observed length (i.e. shorter than the max length), the predictions do not respect the trained parameters and are just noise:

[figure: posterior predictive samples for the training sequences; beyond each observed length the trajectories are noise]
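
For reference, the call behind the plot above is roughly the following (a sketch; train_lengths and train_delays are hypothetical names for my padded training tensors):

import torch
from pyro.infer import Predictive

predictive = Predictive(model, guide=guide, num_samples=100)
samples = predictive(train_lengths, train_delays, None)

# Each per-step site comes back with a leading num_samples dimension,
# e.g. samples['y_5'] has shape (100, num_sequences). Stack the per-step
# y sites into (num_samples, num_sequences, T).
T = int(train_lengths.max())
y_pred = torch.stack([samples[f'y_{t}'] for t in range(T)], dim=-1)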

  3. Inference on new sequences
    For new data I want to predict the next time points. This one, I imagine, is a bit harder. From other threads I understand that I could use amortization, but aside from that, is there another way? E.g. could I train on my training data and then, once converged, fix the global parameters and fit the model on the new data to find their latent positions (see the sketch below)? With that, I could use the solution to problem 2 to make predictions.
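
To make that concrete, below is the kind of two-stage scheme I have in mind — an unverified sketch, assuming the trained guide is an autoguide exposing .median(), with new_lengths / new_delays / new_sequences as placeholders for the new data:

import pyro
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoNormal
from pyro.optim import Adam

# Freeze the globals at point estimates taken from the trained guide.
global_sites = ['latent_std', 'init', 'increment', 'delay']
point_est = {name: value.detach()
             for name, value in guide.median().items()
             if name in global_sites}

# Conditioning marks these sites as observed, so a fresh autoguide over
# the conditioned model only covers the per-sequence x_t sites.
fixed_model = pyro.poutine.condition(model, data=point_est)
new_guide = AutoNormal(fixed_model)
svi = SVI(fixed_model, new_guide, Adam({'lr': 0.01}), Trace_ELBO())
for step in range(1000):
    svi.step(new_lengths, new_delays, new_sequences)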

Thanks!

I managed to hack together a solution to problem 1. The idea is to sample from the guide a bunch of times with poutine.trace on the new data. The problem, of course, is the pyro.sample sites that are tied to the training data, so I use poutine.block to hide them from the trace. I then condition on the remaining (global) samples using poutine.replay.

import pandas as pd
import torch
import tqdm

import pyro

# New data (only covariates, not observed)
tp = 54
lp = torch.IntTensor([tp] * 200)
dp = torch.zeros((200, tp))
dp[:, 30] = 5

# Hide every per-step site that may appear in the guide trace, plus the
# plate's subsample site, so that only the global sites remain.
hide = [f'{t}_{i}' for i in range(lengths.max()) for t in ('x', 'y')]
hide = [*hide, 'sequences']
traces = []
for _ in tqdm.tqdm(list(range(50))):
    # Get guide samples of only the global parameters
    tr = pyro.poutine.trace(pyro.poutine.block(guide, hide=hide)).get_trace(lp, dp, None)

    # Run the model on the new covariates, conditioned on those globals
    rp = pyro.poutine.replay(model, trace=tr)
    tr = pyro.poutine.trace(rp).get_trace(lp, dp)
    values = {}
    for name, props in tr.nodes.items():
        if props['type'] != 'sample' or name == 'sequences':
            continue  # skip non-sample nodes and the plate's subsample site
        v = props['value'].detach()
        # Globals are scalars; the per-step x_t / y_t sites are (200,)-shaped,
        # so .item() would fail on them.
        values[name] = v.item() if v.numel() == 1 else v.numpy()
    traces.append(values)
traces = pd.DataFrame(traces)
traces.head(3)
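
Each row of traces then holds the replayed globals as scalars and the per-step x_t / y_t draws as length-200 arrays, so predictive summaries are easy to pull out, e.g.:

import numpy as np

# Per-sequence predictive mean of y_30 across the 50 replayed traces.
y30_mean = np.stack(traces['y_30'].to_list()).mean(axis=0)  # shape: (200,)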

Is this the correct approach?

I would appreciate any thoughts on this. Thanks so much!