How to improve inference performance of an HMM?

Hello there.

I’m generating a time series as follows:

import torch

true_locs = [40., 80., 160., 40.]
true_scales = [5., 5., 5., 5.]
trans_p = [
    [0.599, 0.001, 0.399, 0.001],
    [0.001, 0.399, 0.001, 0.599],
    [0.599, 0.001, 0.399, 0.001],
    [0.001, 0.399, 0.001, 0.599]
]
K = 4
N = 2000

# One emission distribution per hidden state
gen_distr = [torch.distributions.Normal(true_locs[k], true_scales[k]) for k in range(len(true_locs))]

y = torch.empty(N)
trans_p = torch.tensor(trans_p)
k = 0  # start in state 0
for i in range(len(y)):
    # Emit from the current state, then move to the next state
    y[i] = gen_distr[k].sample()
    k = torch.distributions.Categorical(trans_p[k]).sample().item()

The resulting time series looks like this:
[image]

Then I try to recover the locations, scales, and transition matrix using the following model:

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import config_enumerate

@config_enumerate
def model(y, K):
    # One Dirichlet row of transition probabilities per hidden state
    weights = pyro.sample('weights', dist.Dirichlet(torch.ones(K, K) / K).to_event(1))
    with pyro.plate('components', K):
        locs = pyro.sample('locs', dist.Normal(0., 10.))
        scales = pyro.sample('scales', dist.LogNormal(0., 2.))

    # The chain starts in state 0; each step transitions, then emits an observation
    assignment = 0
    for i, d in pyro.markov(enumerate(y)):
        assignment = pyro.sample(
            'assignment' + str(i),
            dist.Categorical(weights[assignment]),
            infer={"enumerate": "parallel"})
        pyro.sample('obs' + str(i), dist.Normal(locs[assignment], scales[assignment]), obs=d)

I use an AutoDelta guide, and inference is very slow (about 4 seconds per step). Do you have any ideas for improving the performance?

Thanks,
Maxim

Hi Maxim,
I’d recommend using Pyro’s highly optimized DiscreteHMM distribution, which hard-codes most of your model. The Hidden Markov Model example shows how to translate an enumerated model like yours into a model employing DiscreteHMM.