 # Proper use of poutine.mask in a very minimal Markov chain example with heterogeneous start and end times

Hi, I’m new to Pyro and I’m trying to understand how poutine.mask works. So, I simulated a small Markov model (with only a 3-by-3 transition matrix but no emission matrix). (I’ve reviewed https://github.com/pyro-ppl/pyro/blob/dev/examples/hmm.py but the details there were hard to unpack.) I was able to infer the model parameters without the use of vectorization and masking. However, I ran into some obscure warnings that I couldn’t unpack when I tackled the same problem under vectorization and masking. I’ve been trying for a couple of days while digging through multiple references to no avail. Corrections to the existing code along with some concise comments/explanations would be greatly appreciated. Thank you!

import numpy as onp
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
import torch

from pyro.infer.mcmc import MCMC, NUTS

Build dataset:

nof_states = 3
max_timesteps = 13
nof_samples = 10

recur_prob = 0.8

def simulate_data(nof_states, max_timesteps, nof_samples):
onp.random.seed(2020)
epsilon = (1-recur_prob)/(nof_states-1)
transition_prob = (recur_prob-epsilon)*onp.eye(nof_states) + epsilon

``````# Create a list to store the (start time, end time)-tuples of state trajectories
start_end_times = []

# Create a list to store state trajectories
state_trajs = []

for i in range(nof_samples):

# Generate a random start time and set the initial entries of the respective state trajectory to -1s
start_time = onp.random.randint(0, int(0.25*max_timesteps))
state_traj = [-1 for i in range(start_time)]

# Generate a random number of time steps for the trajectory
nof_timesteps = onp.random.randint(int(0.25*max_timesteps), int(1.0*max_timesteps))
for t in range(nof_timesteps):
if t == 0:
state = 0     # Set the first state to 0
else:
state = onp.random.choice(nof_states, p=transition_prob[state])
state_traj.append(state)

# Fill the rest of the current trajectory with -1s
for t in range(len(state_traj), max_timesteps):
state_traj.append(-1)

state_traj = onp.stack(state_traj)

start_end_times.append((start_time, start_time+nof_timesteps))
state_trajs.append(state_traj)

start_end_times = onp.stack(start_end_times)
state_trajs = onp.stack(state_trajs, axis=1)

``````

data = simulate_data(nof_states, max_timesteps, nof_samples);
data

Without vectorization and masking (this worked):

def slow_model(transition_prior, start_end_times, state_trajs):
nof_states = transition_prior.shape
nof_samples = state_trajs.shape
max_timesteps = state_trajs.shape

``````transition_prob = pyro.sample('transition_prob',
dist.Dirichlet(torch.ones([nof_states, nof_states])/nof_states).to_event(1))

for i in pyro.plate('plate_state_trajs', nof_samples):
nof_timesteps = start_end_times[i, 1] - start_end_times[i, 0]
state_traj = state_trajs[i, start_end_times[i, 0]:start_end_times[i, 1]]
state = 0
for t in pyro.markov(range(nof_timesteps-1)):     # "t" at 0 is "start_end_times[i, 0]"
state = pyro.sample('state_{}_{}'.format(i, t),
dist.Categorical(transition_prob[state]),
infer={'enumerate':'parallel'},
obs=state_traj[t+1]
)
``````

nuts_kernel = NUTS(slow_model, jit_compile=True, ignore_jit_warnings=True)

mcmc = MCMC(nuts_kernel, num_samples=100)

posterior = mcmc.run(*data)

mcmc.summary(.9)

With vectorization and masking (this failed):

def fast_model(transition_prior, start_end_times, state_trajs):
nof_states = transition_prior.shape
nof_samples = state_trajs.shape
max_timesteps = state_trajs.shape

``````transition_prob = pyro.sample('transition_prob',
dist.Dirichlet(torch.ones([nof_states, nof_states])/nof_states).to_event(1))

with pyro.plate('plate_state_trajs', nof_samples) as batch:
start_times = start_end_times[batch, 0]
end_times = start_end_times[batch, 1]

states = state_trajs[batch, 0]
for t in pyro.markov(range(max_timesteps-1)):
states = pyro.sample('state_{}'.format(t),
dist.Categorical(transition_prob[states]),
infer={'enumerate':'parallel'},
obs=state_trajs[batch, t+1]
)
``````

nuts_kernel = NUTS(fast_model, jit_compile=True, ignore_jit_warnings=True)

mcmc = MCMC(nuts_kernel, num_samples=100)

posterior = mcmc.run(*data)

mcmc.summary(.9)