 # HMM-like model with sequences of different lengths

I’m trying to use numpyro’s mask with the following model:

``````
z = 0
for i in range(len_seq):
    z ~ Normal(z * w + x[i], sigma)
y ~ Bernoulli(sigmoid(beta * z))
``````

where the sequences don’t have the same length.

The shapes of the resulting samples are not what I expect (see below).

``````
# Example for numpyro.handlers.mask

import numpyro
from numpyro import handlers
import numpy
import jax
import jax.numpy as np
from numpyro.infer import MCMC, NUTS
import numpyro.distributions as dist

def model(y, x_matrix, lengths):
    w = numpyro.sample('w', dist.Uniform(0., 1.))
    sigma = numpyro.sample('sigma', dist.HalfNormal(3.))
    beta = numpyro.sample('beta', dist.HalfNormal(3.))

    with numpyro.plate("data", len(y)):
        z = np.zeros(len(y))
        for i in range(int(lengths.max())):
            # drop steps i >= length from the log density (the sample keeps its full shape)
            with handlers.mask(mask=i < lengths):
                z = numpyro.sample('z_%d' % i, dist.Normal(z * w + x_matrix[:, i], sigma))
        numpyro.sample('y', dist.Bernoulli(logits=beta * z), obs=y)

# Define the variables
y = np.array([1, 0, 0, 1, 1, 0])  # observations
n_obs = len(y)
lengths = np.array([4, 4, 6, 6, 6, 10])  # lengths of the sequences
x_dict = {j: numpy.random.normal(loc=0, scale=1, size=int(size)) for j, size in enumerate(lengths)}

# Transform x_dict into a matrix
x_matrix = numpy.zeros((n_obs, int(lengths.max())))
for j in range(n_obs):
    x_matrix[j, :len(x_dict[j])] = x_dict[j]
x_matrix = np.array(x_matrix)

# Run NUTS
rng_key = jax.random.PRNGKey(0)
kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=100, num_samples=200)
mcmc.run(rng_key, y=y, x_matrix=x_matrix, lengths=lengths)

mcmc_samples = mcmc.get_samples()
print(mcmc_samples.keys())  # dict_keys(['beta', 'sigma', 'w', 'z_0', 'z_1', 'z_2', 'z_3', 'z_4', 'z_5', 'z_6', 'z_7', 'z_8', 'z_9'])
print(mcmc_samples['z_0'].shape)  # (200, 6) as expected
print(mcmc_samples['z_9'].shape)  # (200, 6), but I would expect (200, 1): there is only one sequence of length 10
``````

That makes me think I didn’t use the mask properly. Any hints?

Thanks for the clear example. When you use `handlers.mask()`, the resulting mask is only used during the potential energy (log density) computation to neglect some of the entries in `z`, but you will still see the full matrix captured in the trace, and will need to mask it out manually. To get the masked values in the trace, we will additionally need support for Pyro’s MaskedDistribution. I’ll create a task for this.
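To make the "mask it out manually" step concrete, here is a minimal NumPy sketch. It assumes the posterior draws for a site such as `z_9` come back as an array of shape `(num_samples, num_sequences)`, as in the printed shapes above; random numbers stand in for real draws, and `valid_draws` is a hypothetical helper of our own, not a NumPyro function.

```python
import numpy as np

def valid_draws(site_samples, step, lengths):
    """Keep only the columns (sequences) that actually reach time step `step`.

    site_samples: array of shape (num_samples, num_sequences), e.g. draws of `z_9`.
    step: 0-based time index of the site (9 for `z_9`).
    lengths: array of shape (num_sequences,) with each sequence's length.
    """
    active = step < np.asarray(lengths)  # same condition as the mask `i < lengths`
    return site_samples[:, active]

lengths = np.array([4, 4, 6, 6, 6, 10])
rng = np.random.default_rng(0)
z_0 = rng.normal(size=(200, 6))  # stand-in for mcmc_samples['z_0']
z_9 = rng.normal(size=(200, 6))  # stand-in for mcmc_samples['z_9']

print(valid_draws(z_0, 0, lengths).shape)  # (200, 6)
print(valid_draws(z_9, 9, lengths).shape)  # (200, 1)
```

The masked columns still exist in the trace; they simply carry no information from the likelihood, so dropping them after sampling gives the shapes one would intuitively expect.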

My earlier comment about using `lax.scan` in the inner loop is unlikely to work as is since we are sampling (and recording values) inside the loop. Is the time to compile reasonable with this version?
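For readers less familiar with `lax.scan`: the JAX documentation describes its semantics with a plain Python loop. The sketch below mirrors that description in NumPy (`scan_reference` is our own name, not a JAX function, and it only handles a tuple of arrays as `xs`). The carry is threaded through every step and the per-step outputs are stacked along a new leading axis; when the carry is a tuple, the first return value is that whole tuple, so an accumulated value has to be taken out of it (e.g. `c[0]`).

```python
import numpy as np

def scan_reference(f, init, xs):
    """Pure NumPy mirror of the documented semantics of jax.lax.scan.

    `xs` is a tuple of arrays, all sliced along their leading axis.
    Returns (final_carry, stacked_outputs).
    """
    carry = init
    ys = []
    for t in range(len(xs[0])):
        x_t = tuple(x[t] for x in xs)
        carry, y = f(carry, x_t)
        ys.append(y)
    return carry, np.stack(ys)

def step(carry, x_t):
    total, count = carry          # the carry can be any tuple
    value, = x_t
    total = total + float(value)
    return (total, count + 1), total  # per-step output: running sum so far

final_carry, running = scan_reference(step, (0.0, 0), (np.arange(1.0, 5.0),))
print(final_carry)       # (10.0, 4)
print(running.tolist())  # [1.0, 3.0, 6.0, 10.0]
```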

Update: I made some changes to your model which may result in faster compilation, depending on how long the inner loop is (if it is only 10-12 steps as in the examples, it may not have much impact). There may still be bugs in the code, so please double-check the results of inference.

• The first change is hand-writing the log density contribution of the inner loop using `lax.scan`. Numpyro primitives and effect handlers don't work inside JAX control-flow primitives, so this is a way around that. Note that we are nullifying the contribution of `N(0, 1)` in `_body_fn`. If the inner loop is long (`max(lengths)` is high), using `lax.scan` will result in faster compilation.
• I am recording the masked-out values `zs_masked` so that we can later look at them using `Predictive`. This will be much nicer once we have `numpyro.deterministic` available.
``````
import jax
import jax.numpy as np
import numpy
from jax import lax
from jax import random

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive

def _body_fn(c, val):
    log_prob, w, sigma, lengths = c
    z, i, x_matrix = val
    log_prob = log_prob + np.sum(dist.Normal(z * w + x_matrix, sigma).log_prob(z) * (i < lengths))
    # subtract log_prob contribution from N(0, 1) in the model
    log_prob = log_prob - np.sum(dist.Normal(0., 1.).log_prob(z))
    return (log_prob, w, sigma, lengths), z * (i < lengths)

def model(y, x_matrix, lengths):
    w = numpyro.sample('w', dist.Uniform(0., 1.))
    sigma = numpyro.sample('sigma', dist.HalfNormal(3.))
    beta = numpyro.sample('beta', dist.HalfNormal(3.))
    L = int(np.max(lengths))

    with numpyro.plate("y", len(y)):
        with numpyro.plate("len", L):
            # randomly sampling from N(0, 1) - contribution to PE is subtracted in _body_fn
            zs = numpyro.sample('zs', dist.Normal(0., 1.))
    # c is the final carry (log_prob, w, sigma, lengths); z_masked stacks the per-step outputs
    c, z_masked = lax.scan(_body_fn, (0., w, sigma, lengths), (zs, np.arange(L), x_matrix.T))
    log_prob = c[0]
    numpyro.factor('log_prob', log_prob)
    # No-op: only meant to record these values
    R = numpyro.sample('R', dist.Bernoulli(logits=beta * zs[-1][-1]), obs=y)

# Define the variables
y = np.array([1, 0, 0, 1, 1, 0])  # observations
n_obs = len(y)
lengths = np.array([4, 4, 6, 6, 6, 10])  # lengths of the sequences
x_dict = {j: numpy.random.normal(loc=0, scale=1, size=size) for j, size in enumerate(lengths)}

# Transform x_dict into a matrix
x_matrix = numpy.zeros((n_obs, lengths.max()))
for j in range(n_obs):
    x_matrix[j, :len(x_dict[j])] = x_dict[j]
x_matrix = np.array(x_matrix)

# Run NUTS
rng_key = jax.random.PRNGKey(0)
kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=400, num_samples=500)
mcmc.run(rng_key, y=y, x_matrix=x_matrix, lengths=lengths)

mcmc_samples = mcmc.get_samples()
# Use Predictive to get values for zs_masked
samples = Predictive(model, mcmc_samples)(random.PRNGKey(1), y=y, x_matrix=x_matrix, lengths=lengths)
# Look at a random sample
``````
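The "nullifying" trick in `_body_fn` can be checked with plain arithmetic. The NumPy sketch below is our own illustration, not NumPyro code: `placeholder` plays the role of the automatic `N(0, 1)` term from `sample('zs', ...)`, and `target` is the masked transition density we assume the factor is meant to compute (`Normal(w * z_prev + x, sigma)` on active steps only). Adding the placeholder and the factor gives back exactly the target. As far as we can tell from the snippet, it also shows a side effect worth keeping in mind: because the placeholder is subtracted for every entry, padded entries of `zs` end up with no density term at all.

```python
import numpy as np

def normal_logpdf(x, loc, scale):
    """Elementwise log density of Normal(loc, scale)."""
    return -0.5 * ((x - loc) / scale) ** 2 - np.log(scale) - 0.5 * np.log(2 * np.pi)

rng = np.random.default_rng(1)
lengths = np.array([4, 4, 6, 6, 6, 10])
L, n = int(lengths.max()), len(lengths)
x = rng.normal(size=(L, n))                        # time-major, like x_matrix.T
w, sigma = 0.7, 0.8
active = np.arange(L)[:, None] < lengths[None, :]  # (L, n): True where i < length

def net_log_density(zs):
    """Automatic N(0, 1) term plus the hand-written factor; returns (net, target)."""
    placeholder = normal_logpdf(zs, 0.0, 1.0).sum()      # added by the sample statement
    zs_prev = np.vstack([np.zeros((1, n)), zs[:-1]])     # previous state at each step
    target = (normal_logpdf(zs, w * zs_prev + x, sigma) * active).sum()
    factor = target - placeholder                        # what the factor contributes
    return placeholder + factor, target

zs = rng.normal(size=(L, n))
net, target = net_log_density(zs)
print(np.isclose(net, target))  # True: the N(0, 1) placeholder cancels out

# Sequence 0 has length 4, so step 5 is padding: changing it does not move the density.
zs_changed = zs.copy()
zs_changed[5, 0] += 3.0
print(np.isclose(net_log_density(zs_changed)[0], net))  # True
```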

Let me know if this solves your issue of slow compilation time.

Thank you, using `pyro.handlers.mask` speeds up the compilation time compared to without.

About the use of `handlers.mask()`: in this example, what happens to the variable z of short sequences once `i>len_seq`? I’m wondering what is the z used in the sampling (Bernoulli distribution). My concern is whether the model is fitting the parameters correctly.

About the use of Pyro’s MaskedDistribution: how would the mask be used in this example? I get the same shape issue with the equivalent Pyro code (see below).

``````
# Example for pyro.poutine.mask

import torch
import pyro
from pyro import poutine
from pyro.infer import MCMC, NUTS
import pyro.distributions as dist

def model(y, x_matrix, lengths):
    w = pyro.sample('w', dist.Uniform(0., 1.))
    sigma = pyro.sample('sigma', dist.HalfNormal(3.))
    beta = pyro.sample('beta', dist.HalfNormal(3.))

    with pyro.plate("data", len(y)):
        z = torch.zeros(len(y)).float()
        for i in range(int(lengths.max())):
            with poutine.mask(mask=i < lengths):
                # does not work either with:
                # z = pyro.sample('z_%d' % i, dist.Normal(z * w + x_matrix[:, i], sigma).mask(i < lengths))
                z = pyro.sample('z_%d' % i, dist.Normal(z * w + x_matrix[:, i], sigma))
        pyro.sample('y', dist.Bernoulli(logits=beta * z), obs=y)

#Define the variables
y = torch.tensor([1, 0, 0, 1, 1, 0]).float() #observations
n_obs = len(y)
lengths = torch.tensor([4, 4, 6, 6, 6, 10]) #lengths of the sequences
x_dict = {j: torch.randn(int(size)) for j, size in enumerate(lengths)}

#Transform x_dict into a matrix
x_matrix = torch.zeros((n_obs, int(lengths.max())))
for j in range(n_obs):
    x_matrix[j, :len(x_dict[j])] = x_dict[j]

#Run NUTS
kernel = NUTS(model)
mcmc = MCMC(kernel, warmup_steps=15, num_samples=15)
mcmc.run(y=y, x_matrix=x_matrix, lengths=lengths)

mcmc_samples = mcmc.get_samples()
print(mcmc_samples.keys()) #dict_keys(['beta', 'sigma', 'w', 'z_0', 'z_1', 'z_2', 'z_3', 'z_4', 'z_5', 'z_6', 'z_7', 'z_8', 'z_9'])
print(mcmc_samples['z_0'].shape) #(15, 6) as expected
print(mcmc_samples['z_9'].shape) #(15, 6), but I would expect (15,1): there is only one sequence of length 10
``````

> About the use of `handlers.mask()`: in this example, what happens to the variable z of short sequences once `i>len_seq`? I’m wondering what is the z used in the sampling (Bernoulli distribution).

Okay, I see the issue. Masking only ensures that the entries with `i >= len_seq` don’t contribute to the potential energy, but `z` will always have `shape = (6,)` in your example. So we need to apply the mask manually when post-processing the samples from HMC.

If I understand correctly, in your example the last sequence has the maximum length (is that correct?), and therefore you expect the final z to be a scalar. If that’s the case, you can simply pull the last value out using `dist.Bernoulli(logits=beta * z[-1])`. Does that make sense? I have also updated the numpyro snippet above to take this into account.

@neerajprad No, in the Bernoulli sampling z should be a vector of all the last hidden states (one per sequence). My model is simply an HMM with a single observation at the end of each sequence:

``````
for i_seq in all_seq:
    len_seq = lengths[i_seq]
    z[i_seq, 0] = 0
    for i in range(len_seq):
        z[i_seq, i+1] ~ Normal(z[i_seq, i] * w + x_matrix[i_seq, i], sigma)
    y[i_seq] ~ Bernoulli(sigmoid(beta * z[i_seq, len_seq]))
``````

and I’m trying to vectorize this model using `handlers.mask` or something else.

I’m confused about what `with handlers.mask(i<lengths):` is doing.
In the end I don’t care about the samples of the variable z, I just want the samples of my parameters (w, sigma and beta) to reflect their posterior distribution with my model and data.
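A small NumPy sketch of what vectorizing this pseudocode can look like. The helper names are our own, and `noise` stands for pre-drawn `sigma`-scaled Gaussian noise, fixed in advance so that both versions are deterministic and directly comparable. The assumption-free part is the trick itself: padded steps (`i >= len_seq`) carry the previous state forward, so after the loop `z` holds each sequence's last hidden state, a vector with one entry per sequence rather than the scalar `z[-1]`.

```python
import numpy as np

def simulate_loop(x_matrix, lengths, w, noise):
    """Per-sequence loop, as in the pseudocode: last hidden state of each sequence."""
    n = len(lengths)
    z_last = np.zeros(n)
    for s in range(n):
        z = 0.0
        for i in range(lengths[s]):
            z = z * w + x_matrix[s, i] + noise[s, i]
        z_last[s] = z
    return z_last

def simulate_vectorized(x_matrix, lengths, w, noise):
    """All sequences at once: steps with i >= length keep the previous state."""
    n, L = x_matrix.shape
    z = np.zeros(n)
    for i in range(L):  # loops over time only, not over sequences
        z_new = z * w + x_matrix[:, i] + noise[:, i]
        z = np.where(i < lengths, z_new, z)
    return z

rng = np.random.default_rng(2)
lengths = np.array([4, 4, 6, 6, 6, 10])
n, L = len(lengths), int(lengths.max())
x_matrix = rng.normal(size=(n, L)) * (np.arange(L) < lengths[:, None])  # zero padding
noise = 0.8 * rng.normal(size=(n, L))
w = 0.69

z_loop = simulate_loop(x_matrix, lengths, w, noise)
z_vec = simulate_vectorized(x_matrix, lengths, w, noise)
print(np.allclose(z_loop, z_vec))  # True
print(z_vec.shape)                 # (6,)
```

The vector `z_vec` is what would go into the observation, e.g. `Bernoulli(logits=beta * z_vec)`; the remaining time loop is the part `lax.scan` can replace.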


I’m not sure this is correct. I’m also unsure whether it should be `zs[-1][-1]` or `zs[-1]`. I don’t understand all of the code, but I ran it on simulated data and wasn’t able to recover the parameters that generated the data. Are you sure it should be `zs` in `R = numpyro.sample('R', dist.Bernoulli(logits=beta * zs[-1]), obs=y)` rather than `z_masked`?

Thanks for your code snippet, I understand what the model is doing now. I don’t think you need `mask` in the way we are using it, but I think you could completely vectorize the computation as follows.

Update: I have made multiple revisions to my earlier code based on @vincentbt’s comments and test cases - it verifies the log joint computation between the sequential and vectorized versions now, and we are able to recover the parameter means successfully. The vectorized version should be much faster to compile.

``````
from jax import lax
import jax.numpy as np
from jax import random
import numpy as onp
from numpy.testing import assert_allclose

import numpyro
import numpyro.distributions as dist
from numpyro import handlers
from numpyro.infer import NUTS, MCMC
from numpyro.infer.util import get_potential_fn

def sequential(y, x_matrix, lengths):
    w = numpyro.sample('w', dist.Uniform(0., 1.))
    sigma = numpyro.sample('sigma', dist.HalfNormal(3.))
    beta = numpyro.sample('beta', dist.HalfNormal(3.))
    L = int(np.max(lengths))

    for j, seq in enumerate(x_matrix):
        len_seq = int(lengths[j])
        z = 0
        for i in range(len_seq):
            z = numpyro.sample('z_{}{}'.format(j, i), dist.Normal(z * w + seq[i], sigma))
        R = numpyro.sample('R_{}'.format(j), dist.Normal(beta * z, 1.), obs=y[j])

def _body_fn(c, val):
    z_prev, w, sigma, lengths = c
    z, i, x_matrix = val
    n = dist.Normal(z_prev * w + x_matrix, sigma)
    # Subtract log prob contribution from N(0, 1) in main body
    log_prob = -dist.Normal(0., 1.).log_prob(z)
    # This is only needed so that the final z vector returned
    # has the last z value for each sequence
    z = np.where(i < lengths, z, z_prev)
    # Add the log_prob contribution for all sequences that are still active at step i
    log_prob += n.log_prob(z) * (i < lengths)
    return (z, w, sigma, lengths), log_prob

def vectorized(y, x_matrix, lengths):
    assert len(lengths) == len(y)
    w = numpyro.sample('w', dist.Uniform(0., 1.))
    sigma = numpyro.sample('sigma', dist.HalfNormal(3.))
    beta = numpyro.sample('beta', dist.HalfNormal(3.))
    L = int(np.max(lengths))

    with numpyro.plate("y", len(y)):
        with numpyro.plate("len", L):
            # randomly sampling from N(0, 1) - contribution to PE is subtracted in _body_fn
            zs = numpyro.sample('zs', dist.Normal(0., 1.))
    # the final carry is (z, w, sigma, lengths); vals stacks the per-step log probs
    c, vals = lax.scan(_body_fn,
                       (np.zeros(len(y)), w, sigma, lengths),
                       (zs, np.arange(L), x_matrix.T))
    z = c[0]
    log_prob = vals
    numpyro.factor('log_prob', np.sum(log_prob))
    numpyro.deterministic('log_p', log_prob)
    numpyro.deterministic('z', z)
    R = numpyro.sample('R', dist.Normal(beta * z, 1.), obs=y)

def generate_data():
    # Simulate data
    onp.random.seed(1)
    sigma = 0.8
    w = 0.69
    beta = 2.7
    lengths = np.array(np.concatenate([*210 + *220 + *215]))
    n_obs = len(lengths)
    L = int(lengths.max())
    x_matrix = onp.zeros((n_obs, L))

    for i in range(n_obs):
        x_matrix[i, :lengths[i]] = onp.random.normal(0., 1., size=(int(lengths[i]),))

    model_ = handlers.condition(handlers.seed(sequential, 1), {'w': w, 'sigma': sigma, 'beta': beta})
    trace = handlers.trace(model_).get_trace([None] * n_obs, x_matrix, lengths)
    y = np.stack([trace['R_{}'.format(i)]['value'] for i in range(n_obs)])
    z = np.reshape(np.stack([trace['z_{}{}'.format(i, j)]['value'] if j < lengths[i] else 0
                             for i in range(n_obs)
                             for j in range(L)]), (n_obs, L))
    # Return values for both inputs / observed, as well as the latents
    # (for checking potential energy computation)
    return {'sigma': sigma,
            'w': w,
            'beta': beta,
            'lengths': lengths,
            'x_matrix': x_matrix,
            'z': z,
            'y': y,
            'trace': trace}

def assert_seq_vec_potential_energy_match(data):
    # sequential model: check PE (-log_joint) for generated data point.
    seq_sample = {k: v['value'] for k, v in data['trace'].items() if v['type'] == 'sample'}
    seq_pe = get_potential_fn(random.PRNGKey(2), sequential,
                              model_args=(data['y'], data['x_matrix'], data['lengths']),
                              model_kwargs={})(seq_sample)
    # vectorized model: check PE (-log_joint) for generated data point.
    vec_pe = get_potential_fn(random.PRNGKey(2), vectorized,
                              model_args=(data['y'], data['x_matrix'], data['lengths']),
                              model_kwargs={})(dict(w=data['w'], sigma=data['sigma'], beta=data['beta'],
                                                    zs=data['z'].T))
    assert_allclose(seq_pe, vec_pe, rtol=1e-3)
    print('success: PE is same for vectorized and sequential models.')

def run_inference(model, data):
    rng_key = random.PRNGKey(0)
    kernel = NUTS(model)
    mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000)
    mcmc.run(rng_key, y=data['y'], x_matrix=data['x_matrix'], lengths=data['lengths'])
    mcmc_samples = mcmc.get_samples()
    print(np.mean(mcmc_samples['sigma'], axis=0))
    print(np.mean(mcmc_samples['w'], axis=0))
    print(np.mean(mcmc_samples['beta'], axis=0))

print('generating data...')
data = generate_data()
print('run inference using vectorized model...')
assert_seq_vec_potential_energy_match(data)
run_inference(vectorized, data)
print('run inference using sequential model...')
# WARNING: takes too long to compile
run_inference(sequential, data)
``````
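For a quick sanity check of the masking logic without NumPyro, the NumPy sketch below computes the same log joint (transition terms plus the `Normal(beta * z, 1.)` observation used in the test models above) once with explicit per-sequence loops and once with the masked carry that `_body_fn` implements. It is our own illustration only: it does not call `get_potential_fn`, ignores the priors on `w`, `sigma` and `beta`, and the function names are hypothetical.

```python
import numpy as np

def normal_logpdf(x, loc, scale):
    """Elementwise log density of Normal(loc, scale)."""
    return -0.5 * ((x - loc) / scale) ** 2 - np.log(scale) - 0.5 * np.log(2 * np.pi)

def log_joint_sequential(z, y, x_matrix, lengths, w, sigma, beta):
    """Loop over sequences and over each sequence's own steps."""
    total = 0.0
    for s in range(len(lengths)):
        z_prev = 0.0
        for i in range(lengths[s]):
            total += normal_logpdf(z[s, i], z_prev * w + x_matrix[s, i], sigma)
            z_prev = z[s, i]
        total += normal_logpdf(y[s], beta * z_prev, 1.0)
    return total

def log_joint_vectorized(z, y, x_matrix, lengths, w, sigma, beta):
    """Loop over time only; inactive steps are masked and carry z_prev forward."""
    n, L = z.shape
    z_prev = np.zeros(n)
    total = 0.0
    for i in range(L):  # the loop that lax.scan replaces
        active = i < lengths
        lp = normal_logpdf(z[:, i], z_prev * w + x_matrix[:, i], sigma)
        total += np.sum(lp * active)
        z_prev = np.where(active, z[:, i], z_prev)
    return total + np.sum(normal_logpdf(y, beta * z_prev, 1.0))

rng = np.random.default_rng(4)
lengths = np.array([4, 4, 6, 6, 6, 10])
n, L = len(lengths), int(lengths.max())
x_matrix = rng.normal(size=(n, L))
z = rng.normal(size=(n, L))  # padded entries hold arbitrary values on purpose
y = rng.normal(size=n)
args = (z, y, x_matrix, lengths, 0.69, 0.8, 2.7)
print(np.isclose(log_joint_sequential(*args), log_joint_vectorized(*args)))  # True
```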

Does that roughly address your question? Happy to chat offline in more detail.

@neerajprad I’m just curious (I haven’t looked into the problem in more detail). I understand that we should use non-centering reparameterization together with `lax.scan` here for efficiency. But would it be simpler to only use `mask` on `zs` and use `lax.scan` to compute `z`?

> it verifies the log joint computation between the sequential and vectorized versions

I think with non-centering reparameterization, we want to compute p(zs, …), so the difference between the two versions should be fine. WDYT?


> I understand that we should use non-centering reparameterization together with `lax.scan` here for efficiency.

Great point, yes, I think using a non-centered parametrization in `_body_fn` will be much better! My goal here was only to verify that the model could be rewritten using jax primitives to enable fast compilation.

> But would it be simpler to only use `mask` on `zs` and use `lax.scan` to compute `z`?

I’m not sure if that will be much simpler, though I might be misunderstanding your suggestion. The mask operation is just a small detail in `_body_fn`.

> I think with non-centering reparameterization, we want to compute p(zs, …), so the difference between two versions should be fine.

Yes, in that case the log density computation would differ (unless we change the sequential version accordingly) but that should be fine. We only have this assertion to convince ourselves that the vectorized version is indeed doing the right thing.
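To make the point about the log density concrete: in a non-centered version the model samples standard-normal noise `eps` and builds `z` deterministically, so its log density is evaluated on `eps`, not on `z`. The NumPy sketch below is our own illustration and assumes the update `z_i = w * z_{i-1} + x_i + sigma * eps_i` on active steps. It shows that the two log densities differ only by the log-Jacobian term `-log(sigma)` per active entry, which is why an exact PE match against the centered sequential model should not be expected.

```python
import numpy as np

def normal_logpdf(x, loc, scale):
    """Elementwise log density of Normal(loc, scale)."""
    return -0.5 * ((x - loc) / scale) ** 2 - np.log(scale) - 0.5 * np.log(2 * np.pi)

rng = np.random.default_rng(3)
lengths = np.array([4, 4, 6, 6, 6, 10])
n, L = len(lengths), int(lengths.max())
active = np.arange(L)[None, :] < lengths[:, None]  # (n, L): True where i < length
x_matrix = rng.normal(size=(n, L))
w, sigma = 0.69, 0.8

# Non-centered: draw standard-normal noise, then build z deterministically.
eps = rng.normal(size=(n, L))
z = np.zeros((n, L))
z_prev = np.zeros(n)
for i in range(L):
    z_i = z_prev * w + x_matrix[:, i] + sigma * eps[:, i]
    z_prev = np.where(active[:, i], z_i, z_prev)  # padded steps keep the last state
    z[:, i] = z_prev

# Centered log density of z vs. non-centered log density of eps (active entries only).
z_lag = np.hstack([np.zeros((n, 1)), z[:, :-1]])
lp_centered = np.sum(normal_logpdf(z, z_lag * w + x_matrix, sigma) * active)
lp_noncentered = np.sum(normal_logpdf(eps, 0.0, 1.0) * active)

# z = loc + sigma * eps contributes -log(sigma) per active entry.
print(np.isclose(lp_centered, lp_noncentered - active.sum() * np.log(sigma)))  # True
```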

Understood! I thought you were using reparameterization, but after looking into the solution again, I understand that you just want to construct a corresponding one for the sequential model.

For users who are interested in the non-centered reparameterization I mentioned above, here is the corresponding model (a slight modification of @neerajprad’s answer). This parameterization will mix much better than the centered one. Note that future versions of JAX will include a grad rule for while loops, which will allow us to use `lax.fori_loop` instead of `lax.scan` here (not important, though).

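``````
def scan_z(noises, x_matrix, w, lengths):
    def _body_fn(z_prev, val):
        noise, i, x_col = val
        z = z_prev * w + x_col + noise
        # padded steps (i >= length) keep the previous state
        z = np.where(i < lengths, z, z_prev)
        return z, None

    # noises has shape (L, num_sequences): scan over time, one state per sequence
    return lax.scan(_body_fn, np.zeros(noises.shape[1]),
                    (noises, np.arange(noises.shape[0]), x_matrix.T))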

def reparam_model(y, x_matrix, lengths):
    w = numpyro.sample('w', dist.Uniform(0., 1.))
    sigma = numpyro.sample('sigma', dist.HalfNormal(3.))
    beta = numpyro.sample('beta', dist.HalfNormal(3.))
    L = int(np.max(lengths))
    with numpyro.plate("y", len(y)):
        with numpyro.plate("len", L):