I tried to implement a mixture of HMM model. It is similar to what is mentioned in this paper with a simple multinomial
Also, I used the ‘model_1’ in HMM example as my start point:
http://pyro.ai/examples/hmm.html
Basically, I just add a latent variable for membership assignment of HMM, which is similar to Gaussian Mixture Model. It should work with the same music dataset.
However, to figure out the correct shape is more complicated than my ability. Here is my failed code:
from __future__ import absolute_import, division, print_function
import argparse
import logging
import torch
import torch.nn as nn
from torch.distributions import constraints
import dmm.polyphonic_data_loader as poly
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.contrib.autoguide import AutoDelta
from pyro.infer import SVI, JitTraceEnum_ELBO, TraceEnum_ELBO, infer_discrete, TracePredictive, config_enumerate
from pyro.optim import Adam
from pyro.util import ignore_jit_warnings
import numpy as np
import matplotlib.pyplot as plt
from generate_data import generate_data
logging.basicConfig(format='%(relativeCreated) 9d %(message)s', level=logging.INFO)
# with user group
@config_enumerate
def model_2(sequences, lengths, args, batch_size=None, include_prior=True):
with ignore_jit_warnings():
num_sequences, max_length, data_dim = map(int, sequences.shape)
assert lengths.shape == (num_sequences,)
assert lengths.max() <= max_length
with poutine.mask(mask=include_prior):
p = pyro.sample('user_group', dist.Dirichlet(0.5 * torch.ones(args.user_dim)))
with pyro.plate('user_plate', args.user_dim, dim=-5):
probs_x = pyro.sample("probs_x",
dist.Dirichlet(0.5 * torch.ones((args.hidden_dim, args.hidden_dim)))
.to_event(1))
probs_y = pyro.sample("probs_y",
dist.Beta(0.1, 0.9)
.expand([args.hidden_dim, data_dim])
.to_event(2))
tones_plate = pyro.plate("tones", data_dim, dim=-1)
with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch:
lengths = lengths[batch]
user_ind = pyro.sample('user_ind', dist.Categorical(p))
user_probs_x = probs_x[user_ind]
user_probs_y = probs_y[user_ind]
x = 0
for t in pyro.markov(range(max_length if args.jit else lengths.max())):
with poutine.mask(mask=(t < lengths).unsqueeze(-1)):
x = pyro.sample("x_{}".format(t), dist.Categorical(user_probs_x[x]),
infer={"enumerate": "parallel"})
with tones_plate:
pyro.sample("y_{}".format(t), dist.Bernoulli(user_probs_y[x.squeeze(-1)]),
obs=sequences[batch, t])
models = {name[len('model_'):]: model
for name, model in globals().items()
if name.startswith('model_')}
models = {name[len('model_'):]: model
for name, model in globals().items()
if name.startswith('model_')}
def main(args):
if args.cuda:
torch.set_default_tensor_type('torch.cuda.FloatTensor')
logging.info('Loading data')
# data = poly.load_data(poly.JSB_CHORALES)
if args.model == '1':
data, p_transit, p_emit = generate_data(args, with_user_group=False, with_station_group=False)
elif args.model == '2':
data, p_transit_group, p_emit_group = generate_data(args, with_user_group=True, with_station_group=False)
import time
time.sleep(2)
logging.info('-' * 40)
model = models[args.model]
logging.info('Training {} on {} sequences'.format(
model.__name__, len(data['train']['sequences'])))
sequences = data['train']['sequences']
lengths = data['train']['sequence_lengths']
if args.truncate:
lengths.clamp_(max=args.truncate)
sequences = sequences[:, :args.truncate]
num_observations = float(lengths.sum())
pyro.set_rng_seed(65347)
pyro.clear_param_store()
pyro.enable_validation(True)
guide = AutoDelta(poutine.block(model, expose_fn=lambda msg: msg["name"].startswith("probs_")))
if args.print_shapes:
# first_available_dim = -2 if model is model_0 else -3
first_available_dim = -3
guide_trace = poutine.trace(guide).get_trace(
sequences, lengths, args=args, batch_size=args.batch_size)
model_trace = poutine.trace(
poutine.replay(poutine.enum(model, first_available_dim), guide_trace)).get_trace(
sequences, lengths, args=args, batch_size=args.batch_size)
logging.info(model_trace.format_shapes())
Elbo = JitTraceEnum_ELBO if args.jit else TraceEnum_ELBO
if args.model == 1: elbo = Elbo(max_plate_nesting=2)
elif args.model == 2: elbo = Elbo(max_plate_nesting=3)
optim = Adam({'lr': args.learning_rate})
svi = SVI(model, guide, optim, elbo)
logging.info('Step\tLoss')
losses = []
for step in range(args.num_steps):
loss = svi.step(sequences, lengths, args=args, batch_size=args.batch_size)
losses.append(loss / num_observations)
logging.info('{: >5d}\t{}'.format(step, loss / num_observations))
train_loss = elbo.loss(model, guide, sequences, lengths, args, include_prior=False)
logging.info('training loss = {}'.format(train_loss / num_observations))
print(guide.median())
I got error msg:
probs_2d = probs.view(-1, self._num_events)
RuntimeError: invalid argument 2: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Call .contiguous() before .view(). at /pytorch/aten/src/TH/generic/THTensor.cpp:213
Trace Shapes:
Param Sites:
Sample Sites:
probs_x dist 2 1 1 1 1 | 3 3
value 2 1 1 1 1 | 3 3
probs_y dist 2 1 1 1 1 | 3 7
value 2 1 1 1 1 | 3 7
Trace Shapes:
Param Sites:
Sample Sites:
need help!!! Thank you!