Proper use of poutine.mask in a very minimal Markov chain example with heterogeneous start and end times

Hi, I’m new to Pyro and I’m trying to understand how poutine.mask works. So, I simulated a small Markov model (with only a 3-by-3 transmission matrix but no emission matrix). (I’ve reviewed pyro/examples/hmm.py at dev · pyro-ppl/pyro · GitHub but the details there were hard to unpack.) I was able to infer the model parameters without the use of vectorization and masking. However, I ran into some obscure warnings that I couldn’t unpack when I tackled the same problem under vectorization and masking. I’ve been trying for a couple of days while digging through multiple references to no avail. Corrections to the existing code along with some concise comments/explanations would be greatly appreciated. Thank you!

import numpy as onp
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
import torch

from pyro.infer.mcmc import MCMC, NUTS

Build dataset:

nof_states = 3
max_timesteps = 13
nof_samples = 10

recur_prob = 0.8

def simulate_data(nof_states, max_timesteps, nof_samples):
onp.random.seed(2020)
epsilon = (1-recur_prob)/(nof_states-1)
transition_prob = (recur_prob-epsilon)*onp.eye(nof_states) + epsilon

# Create a list to store the (start time, end time)-tuples of state trajectories
start_end_times = []

# Create a list to store state trajectories
state_trajs = []

for i in range(nof_samples):
    
    # Generate a random start time and set the initial entries of the respective state trajectory to -1s
    start_time = onp.random.randint(0, int(0.25*max_timesteps))
    state_traj = [-1 for i in range(start_time)]
    
    # Generate a random number of time steps for the trajectory
    nof_timesteps = onp.random.randint(int(0.25*max_timesteps), int(1.0*max_timesteps))
    for t in range(nof_timesteps):
        if t == 0:
            state = 0     # Set the first state to 0
        else:
            state = onp.random.choice(nof_states, p=transition_prob[state])
        state_traj.append(state)
    
    # Fill the rest of the current trajectory with -1s
    for t in range(len(state_traj), max_timesteps):
        state_traj.append(-1)

    state_traj = onp.stack(state_traj)
    
    start_end_times.append((start_time, start_time+nof_timesteps))
    state_trajs.append(state_traj)

start_end_times = onp.stack(start_end_times)
state_trajs = onp.stack(state_trajs, axis=1)

return torch.tensor(transition_prob), torch.tensor(start_end_times), torch.tensor(state_trajs.T)

data = simulate_data(nof_states, max_timesteps, nof_samples);
data

Without vectorization and masking (this worked):

def slow_model(transition_prior, start_end_times, state_trajs):
nof_states = transition_prior.shape[0]
nof_samples = state_trajs.shape[0]
max_timesteps = state_trajs.shape[1]

transition_prob = pyro.sample('transition_prob',
                              dist.Dirichlet(torch.ones([nof_states, nof_states])/nof_states).to_event(1))

for i in pyro.plate('plate_state_trajs', nof_samples):
    nof_timesteps = start_end_times[i, 1] - start_end_times[i, 0]
    state_traj = state_trajs[i, start_end_times[i, 0]:start_end_times[i, 1]]
    state = 0
    for t in pyro.markov(range(nof_timesteps-1)):     # "t" at 0 is "start_end_times[i, 0]"
        state = pyro.sample('state_{}_{}'.format(i, t),
                            dist.Categorical(transition_prob[state]),
                            infer={'enumerate':'parallel'},
                            obs=state_traj[t+1]
                           )

nuts_kernel = NUTS(slow_model, jit_compile=True, ignore_jit_warnings=True)

mcmc = MCMC(nuts_kernel, num_samples=100)

posterior = mcmc.run(*data)

mcmc.summary(.9)

With vectorization and masking (this failed):

def fast_model(transition_prior, start_end_times, state_trajs):
nof_states = transition_prior.shape[0]
nof_samples = state_trajs.shape[0]
max_timesteps = state_trajs.shape[1]

transition_prob = pyro.sample('transition_prob',
                              dist.Dirichlet(torch.ones([nof_states, nof_states])/nof_states).to_event(1))


with pyro.plate('plate_state_trajs', nof_samples) as batch:
    start_times = start_end_times[batch, 0]
    end_times = start_end_times[batch, 1]
    
    states = state_trajs[batch, 0]
    for t in pyro.markov(range(max_timesteps-1)):            
        with poutine.mask(mask=((t >= start_times) & (t < (end_times-1))).unsqueeze(-1)):
            states = pyro.sample('state_{}'.format(t),
                                 dist.Categorical(transition_prob[states]),
                                 infer={'enumerate':'parallel'},
                                 obs=state_trajs[batch, t+1]
                                )

nuts_kernel = NUTS(fast_model, jit_compile=True, ignore_jit_warnings=True)

mcmc = MCMC(nuts_kernel, num_samples=100)

posterior = mcmc.run(*data)

mcmc.summary(.9)