Mixture HMM model with issues

I tried to implement a mixture of HMM model. It is similar to what is mentioned in this paper with a simple multinomial

Also, I used the ‘model_1’ in HMM example as my start point:
Basically, I just add a latent variable for membership assignment of HMM, which is similar to Gaussian Mixture Model. It should work with the same music dataset.

However, to figure out the correct shape is more complicated than my ability. Here is my failed code:

from __future__ import absolute_import, division, print_function

import argparse
import logging

import torch
import torch.nn as nn
from torch.distributions import constraints

import dmm.polyphonic_data_loader as poly
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.contrib.autoguide import AutoDelta
from pyro.infer import SVI, JitTraceEnum_ELBO, TraceEnum_ELBO, infer_discrete, TracePredictive, config_enumerate
from pyro.optim import Adam
from pyro.util import ignore_jit_warnings

import numpy as np
import matplotlib.pyplot as plt

from generate_data import generate_data

logging.basicConfig(format='%(relativeCreated) 9d %(message)s', level=logging.INFO)

# with user group
def model_2(sequences, lengths, args, batch_size=None, include_prior=True):
    with ignore_jit_warnings():
        num_sequences, max_length, data_dim = map(int, sequences.shape)
        assert lengths.shape == (num_sequences,)
        assert lengths.max() <= max_length
    with poutine.mask(mask=include_prior):
        p = pyro.sample('user_group', dist.Dirichlet(0.5 * torch.ones(args.user_dim)))
        with pyro.plate('user_plate', args.user_dim, dim=-5):
            probs_x = pyro.sample("probs_x",
                                  dist.Dirichlet(0.5 * torch.ones((args.hidden_dim, args.hidden_dim)))
            probs_y = pyro.sample("probs_y",
                              dist.Beta(0.1, 0.9)
                                  .expand([args.hidden_dim, data_dim])
    tones_plate = pyro.plate("tones", data_dim, dim=-1)
    with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch:
        lengths = lengths[batch]
        user_ind = pyro.sample('user_ind', dist.Categorical(p))
        user_probs_x = probs_x[user_ind]
        user_probs_y = probs_y[user_ind]
        x = 0
        for t in pyro.markov(range(max_length if args.jit else lengths.max())):
            with poutine.mask(mask=(t < lengths).unsqueeze(-1)):
                x = pyro.sample("x_{}".format(t), dist.Categorical(user_probs_x[x]),
                            infer={"enumerate": "parallel"})
                with tones_plate:
                    pyro.sample("y_{}".format(t), dist.Bernoulli(user_probs_y[x.squeeze(-1)]),
                                obs=sequences[batch, t])

models = {name[len('model_'):]: model
          for name, model in globals().items()
          if name.startswith('model_')}

models = {name[len('model_'):]: model
          for name, model in globals().items()
          if name.startswith('model_')}

def main(args):
    if args.cuda:

    logging.info('Loading data')
    # data = poly.load_data(poly.JSB_CHORALES)
    if args.model == '1':
        data, p_transit, p_emit = generate_data(args, with_user_group=False, with_station_group=False)
    elif args.model == '2':
        data, p_transit_group, p_emit_group = generate_data(args, with_user_group=True, with_station_group=False)
    import time

    logging.info('-' * 40)
    model = models[args.model]
    logging.info('Training {} on {} sequences'.format(
        model.__name__, len(data['train']['sequences'])))
    sequences = data['train']['sequences']
    lengths = data['train']['sequence_lengths']

    if args.truncate:
        sequences = sequences[:, :args.truncate]
    num_observations = float(lengths.sum())
    guide = AutoDelta(poutine.block(model, expose_fn=lambda msg: msg["name"].startswith("probs_")))
    if args.print_shapes:
        # first_available_dim = -2 if model is model_0 else -3
        first_available_dim = -3
        guide_trace = poutine.trace(guide).get_trace(
            sequences, lengths, args=args, batch_size=args.batch_size)
        model_trace = poutine.trace(
            poutine.replay(poutine.enum(model, first_available_dim), guide_trace)).get_trace(
            sequences, lengths, args=args, batch_size=args.batch_size)
    Elbo = JitTraceEnum_ELBO if args.jit else TraceEnum_ELBO
    if args.model == 1: elbo = Elbo(max_plate_nesting=2)
    elif args.model == 2: elbo = Elbo(max_plate_nesting=3)
    optim = Adam({'lr': args.learning_rate})
    svi = SVI(model, guide, optim, elbo)
    losses = []
    for step in range(args.num_steps):
        loss = svi.step(sequences, lengths, args=args, batch_size=args.batch_size)
        losses.append(loss / num_observations)
        logging.info('{: >5d}\t{}'.format(step, loss / num_observations))
    train_loss = elbo.loss(model, guide, sequences, lengths, args, include_prior=False)
    logging.info('training loss = {}'.format(train_loss / num_observations))

I got error msg:

    probs_2d = probs.view(-1, self._num_events)
RuntimeError: invalid argument 2: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Call .contiguous() before .view(). at /pytorch/aten/src/TH/generic/THTensor.cpp:213
Trace Shapes:                
 Param Sites:                
Sample Sites:                
 probs_x dist 2 1 1 1 1 | 3 3
        value 2 1 1 1 1 | 3 3
 probs_y dist 2 1 1 1 1 | 3 7
        value 2 1 1 1 1 | 3 7
Trace Shapes:
 Param Sites:
Sample Sites:

need help!!! Thank you!

To my understanding, it seems to be related to the pyro.markov behavior, which makes it different from Gaussian Mixture.