Hi, I’m new to Pyro and I’m trying to understand how poutine.mask works. So, I simulated a small Markov model (with only a 3-by-3 transmission matrix but no emission matrix). (I’ve reviewed pyro/examples/hmm.py at dev · pyro-ppl/pyro · GitHub but the details there were hard to unpack.) I was able to infer the model parameters without the use of vectorization and masking. However, I ran into some obscure warnings that I couldn’t unpack when I tackled the same problem under vectorization and masking. I’ve been trying for a couple of days while digging through multiple references to no avail. Corrections to the existing code along with some concise comments/explanations would be greatly appreciated. Thank you!
import numpy as onp
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
import torchfrom pyro.infer.mcmc import MCMC, NUTS
Build dataset:
nof_states = 3
max_timesteps = 13
nof_samples = 10recur_prob = 0.8
def simulate_data(nof_states, max_timesteps, nof_samples):
onp.random.seed(2020)
epsilon = (1-recur_prob)/(nof_states-1)
transition_prob = (recur_prob-epsilon)*onp.eye(nof_states) + epsilon# Create a list to store the (start time, end time)-tuples of state trajectories start_end_times = [] # Create a list to store state trajectories state_trajs = [] for i in range(nof_samples): # Generate a random start time and set the initial entries of the respective state trajectory to -1s start_time = onp.random.randint(0, int(0.25*max_timesteps)) state_traj = [-1 for i in range(start_time)] # Generate a random number of time steps for the trajectory nof_timesteps = onp.random.randint(int(0.25*max_timesteps), int(1.0*max_timesteps)) for t in range(nof_timesteps): if t == 0: state = 0 # Set the first state to 0 else: state = onp.random.choice(nof_states, p=transition_prob[state]) state_traj.append(state) # Fill the rest of the current trajectory with -1s for t in range(len(state_traj), max_timesteps): state_traj.append(-1) state_traj = onp.stack(state_traj) start_end_times.append((start_time, start_time+nof_timesteps)) state_trajs.append(state_traj) start_end_times = onp.stack(start_end_times) state_trajs = onp.stack(state_trajs, axis=1) return torch.tensor(transition_prob), torch.tensor(start_end_times), torch.tensor(state_trajs.T)
data = simulate_data(nof_states, max_timesteps, nof_samples);
data
Without vectorization and masking (this worked):
def slow_model(transition_prior, start_end_times, state_trajs):
nof_states = transition_prior.shape[0]
nof_samples = state_trajs.shape[0]
max_timesteps = state_trajs.shape[1]transition_prob = pyro.sample('transition_prob', dist.Dirichlet(torch.ones([nof_states, nof_states])/nof_states).to_event(1)) for i in pyro.plate('plate_state_trajs', nof_samples): nof_timesteps = start_end_times[i, 1] - start_end_times[i, 0] state_traj = state_trajs[i, start_end_times[i, 0]:start_end_times[i, 1]] state = 0 for t in pyro.markov(range(nof_timesteps-1)): # "t" at 0 is "start_end_times[i, 0]" state = pyro.sample('state_{}_{}'.format(i, t), dist.Categorical(transition_prob[state]), infer={'enumerate':'parallel'}, obs=state_traj[t+1] )
nuts_kernel = NUTS(slow_model, jit_compile=True, ignore_jit_warnings=True)
mcmc = MCMC(nuts_kernel, num_samples=100)
posterior = mcmc.run(*data)
mcmc.summary(.9)
With vectorization and masking (this failed):
def fast_model(transition_prior, start_end_times, state_trajs):
nof_states = transition_prior.shape[0]
nof_samples = state_trajs.shape[0]
max_timesteps = state_trajs.shape[1]transition_prob = pyro.sample('transition_prob', dist.Dirichlet(torch.ones([nof_states, nof_states])/nof_states).to_event(1)) with pyro.plate('plate_state_trajs', nof_samples) as batch: start_times = start_end_times[batch, 0] end_times = start_end_times[batch, 1] states = state_trajs[batch, 0] for t in pyro.markov(range(max_timesteps-1)): with poutine.mask(mask=((t >= start_times) & (t < (end_times-1))).unsqueeze(-1)): states = pyro.sample('state_{}'.format(t), dist.Categorical(transition_prob[states]), infer={'enumerate':'parallel'}, obs=state_trajs[batch, t+1] )
nuts_kernel = NUTS(fast_model, jit_compile=True, ignore_jit_warnings=True)
mcmc = MCMC(nuts_kernel, num_samples=100)
posterior = mcmc.run(*data)
mcmc.summary(.9)