HMM-like model with sequences of different lengths

I’m trying to use numpyro’s mask with the following model:

z = 0
for i in range(len_seq):
    z ~ Normal (z * w + x[i], sigma)
y ~ Bernoulli (sigmoid(beta * z))

where the sequences don’t have the same length.
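For reference, here is a minimal NumPy sketch of the generative process above. It handles a batch of sequences of different lengths and simulates them one sequence at a time. The helper name `simulate` and the parameter values are illustrative assumptions and do not come from the original post.

```python
import numpy as np


def simulate(lengths, w, sigma, beta, seed=0):
    """Hypothetical helper: forward-simulate the HMM-like model per sequence."""
    rng = np.random.default_rng(seed)
    xs, ys, z_last = [], [], []
    for n in lengths:
        x = rng.normal(0.0, 1.0, size=n)     # inputs for this sequence
        z = 0.0                               # initial hidden state
        for i in range(n):
            z = rng.normal(z * w + x[i], sigma)   # z ~ Normal(z * w + x[i], sigma)
        p = 1.0 / (1.0 + np.exp(-beta * z))  # sigmoid(beta * z)
        ys.append(int(rng.random() < p))      # y ~ Bernoulli(p)
        xs.append(x)
        z_last.append(z)
    return xs, np.array(ys), np.array(z_last)


xs, y, z_last = simulate([4, 4, 6, 6, 6, 10], w=0.7, sigma=0.8, beta=2.7)
print([len(x) for x in xs])  # [4, 4, 6, 6, 6, 10]
print(y.shape)               # (6,)
```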

I get an unexpected result on the shape of the samples (see below).

#Example for numpyro.handlers.mask

import numpyro
from numpyro import handlers
import numpy
import jax
import jax.numpy as np
from numpyro.infer import MCMC, NUTS
import numpyro.distributions as dist

def model(y, x_matrix, lengths):
    w = numpyro.sample('w', dist.Uniform(0., 1.))
    sigma = numpyro.sample('sigma', dist.HalfNormal(3.))
    beta = numpyro.sample('beta', dist.HalfNormal(3.))
    
    with numpyro.plate("data", len(y)):
        z = np.zeros(len(y))
        for i in range(lengths.max()):
            with handlers.mask((i<lengths)):
                z = numpyro.sample('z_%d'%i, dist.Normal(z * w + x_matrix[:,i], sigma))
        numpyro.sample('y', dist.Bernoulli(logits=beta*z), obs=y)
        

#Define the variables
y = np.array([1, 0, 0, 1, 1, 0]) #observations
n_obs = len(y)
lengths = np.array([4, 4, 6, 6, 6, 10]) #lengths of the sequences
x_dict = {j : numpy.random.normal(loc=0, scale=1, size=size) for j, size in enumerate(lengths)}

#Transform x_dict into a matrix
x_matrix = numpy.zeros((n_obs, lengths.max()))
for j in range(n_obs):
    x_matrix[j, :len(x_dict[j])] = x_dict[j]
x_matrix = np.array(x_matrix)
    
#Run NUTS
rng_key = jax.random.PRNGKey(0)
kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=100, num_samples=200)
mcmc.run(rng_key, y=y, x_matrix=x_matrix, lengths=lengths)

mcmc_samples = mcmc.get_samples()
print(mcmc_samples.keys()) #dict_keys(['beta', 'sigma', 'w', 'z_0', 'z_1', 'z_2', 'z_3', 'z_4', 'z_5', 'z_6', 'z_7', 'z_8', 'z_9'])
print(mcmc_samples['z_0'].shape) #(200, 6) as expected
print(mcmc_samples['z_9'].shape) #(200, 6), but I would expect (200,1): there is only one sequence of length 10

That makes me think I didn’t use the mask properly. Any hints?

Thanks for the clear example. When you use handlers.mask(), the resulting mask is only used during the potential energy (log density) computation to neglect some of the entries in z, but you will still see the full matrix captured in the trace, and will need to mask it out manually. To get the masked values in the trace, we will additionally need support for Pyro’s MaskedDistribution. I’ll create a task for this.
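Masking the samples by hand could look like the following minimal sketch. It assumes the per-step sites `z_0 … z_{L-1}` have been stacked into an array of shape `(num_samples, L, N)`. The random array is only a stand-in for real MCMC output, and the variable names are illustrative.

```python
import numpy as np

lengths = np.array([4, 4, 6, 6, 6, 10])
N, L, num_samples = len(lengths), int(lengths.max()), 200

# Stand-in for np.stack([samples['z_%d' % i] for i in range(L)], axis=1)
z_stacked = np.random.default_rng(0).normal(size=(num_samples, L, N))

# valid[i, n] is True when step i exists for sequence n; shape (L, N)
valid = np.arange(L)[:, None] < lengths

# Option 1: keep the dense array but blank out padded steps with NaN
z_masked = np.where(valid, z_stacked, np.nan)

# Option 2: keep only the sequences that actually reach step 9
z_9 = z_stacked[:, 9, valid[9]]
print(z_9.shape)  # (200, 1)
```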

My earlier comment about using lax.scan in the inner loop is unlikely to work as is since we are sampling (and recording values) inside the loop. Is the time to compile reasonable with this version?

Update: I made some changes to your model which may result in faster compilation, depending on how big the inner loop is (if it is only 10-12 steps, as in the examples, it may not have much impact). There may still be bugs in the code, so please double-check the results of inference.

  • The first change is just hand-writing the log density contribution of the inner loop using lax.scan. NumPyro primitives and effect handlers don’t work inside JAX control-flow primitives, so this is a way around that. Note that we are nullifying the contribution of N(0, 1) in _body_fn. If the inner loop is long (max(lengths) is high), using lax.scan will result in faster compilation.
  • I am recording the masked out values zs_masked so that we can later look at these using Predictive. This will be much nicer once we have numpyro.deterministic available.
import jax
import jax.numpy as np
import numpy
from jax import lax
from jax import random

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive


def _body_fn(c, val):
    log_prob, w, sigma, lengths = c
    z, i, x_matrix = val
    log_prob = log_prob + np.sum(dist.Normal(z * w + x_matrix, sigma).log_prob(z) * (i < lengths))
    # subtract log_prob contribution from N(0, 1) in the model
    log_prob = log_prob - np.sum(dist.Normal(0., 1.).log_prob(z))
    return (log_prob, w, sigma, lengths), z * (i < lengths)


def model(y, x_matrix, lengths):
    w = numpyro.sample('w', dist.Uniform(0., 1.))
    sigma = numpyro.sample('sigma', dist.HalfNormal(3.))
    beta = numpyro.sample('beta', dist.HalfNormal(3.))
    L = int(np.max(lengths))

    with numpyro.plate("y", len(y)):
        with numpyro.plate("len", L):
            # randomly sampling from N(0, 1) - contribution to PE is subtracted in _body_fn
            zs = numpyro.sample('zs', dist.Normal(0., 1.))
            c, z_masked = lax.scan(_body_fn, (0., w, sigma, lengths), (zs, np.arange(L), x_matrix.T))
            log_prob = c[0]
            numpyro.factor('log_prob', log_prob)
            # No-op: only meant to record these values
            numpyro.sample('zs_masked', dist.Delta(z_masked), z_masked)
        R = numpyro.sample('R', dist.Bernoulli(logits=beta * zs[-1][-1]), obs=y)


# Define the variables
y = np.array([1, 0, 0, 1, 1, 0])  # observations
n_obs = len(y)
lengths = np.array([4, 4, 6, 6, 6, 10])  # lengths of the sequences
x_dict = {j: numpy.random.normal(loc=0, scale=1, size=size) for j, size in enumerate(lengths)}

# Transform x_dict into a matrix
x_matrix = numpy.zeros((n_obs, lengths.max()))
for j in range(n_obs):
    x_matrix[j, :len(x_dict[j])] = x_dict[j]
x_matrix = np.array(x_matrix)

# Run NUTS
rng_key = jax.random.PRNGKey(0)
kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=400, num_samples=500)
mcmc.run(rng_key, y=y, x_matrix=x_matrix, lengths=lengths)

mcmc_samples = mcmc.get_samples()
# Use Predictive to get values for zs_masked
samples = Predictive(model, mcmc_samples)(random.PRNGKey(1), y=y, x_matrix=x_matrix, lengths=lengths)
# Look at a random sample
print(samples['zs_masked'][100])

Let me know if this solves your issue of slow compilation time.
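The log-density trick from the first bullet can be checked in isolation with plain JAX. At each step the scan adds the masked transition density and subtracts the N(0, 1) density that the `zs` sample site already contributes, so padded steps contribute nothing overall. This is a sketch under a few assumptions. The function names are hypothetical, `jax.scipy.stats.norm.logpdf` stands in for `dist.Normal(...).log_prob`, and each step conditions on the previous state `z_prev`, as the later vectorized version in this thread does. It compares the result against a plain Python loop over the valid steps only.

```python
import jax
import jax.numpy as jnp
from jax import lax
from jax.scipy.stats import norm


def masked_log_density(zs, x_matrix, lengths, w, sigma):
    """zs: (L, N) latents; x_matrix: (N, L) padded inputs; lengths: (N,)."""
    L, N = zs.shape

    def body(carry, val):
        z_prev, total = carry
        z, i, x_col = val
        valid = i < lengths
        step = norm.logpdf(z, z_prev * w + x_col, sigma) * valid
        step = step - norm.logpdf(z)  # cancel the N(0, 1) prior on `zs`
        z_next = jnp.where(valid, z, z_prev)  # keep the last valid state
        return (z_next, total + jnp.sum(step)), None

    init = (jnp.zeros(N), jnp.zeros(()))
    (_, total), _ = lax.scan(body, init, (zs, jnp.arange(L), x_matrix.T))
    return total


def loop_log_density(zs, x_matrix, lengths, w, sigma):
    """Reference: sum the transition densities over valid steps only."""
    total = 0.0
    for n, length in enumerate(lengths):
        z_prev = 0.0
        for i in range(int(length)):
            total += norm.logpdf(zs[i, n], z_prev * w + x_matrix[n, i], sigma)
            z_prev = zs[i, n]
    return total


key_z, key_x = jax.random.split(jax.random.PRNGKey(0))
lengths = jnp.array([4, 4, 6, 6, 6, 10])
L, N = int(lengths.max()), len(lengths)
zs = jax.random.normal(key_z, (L, N))
x_matrix = jax.random.normal(key_x, (N, L))

# Adding back the N(0, 1) prior should recover the masked log density
scan_lp = masked_log_density(zs, x_matrix, lengths, 0.7, 0.8) + jnp.sum(norm.logpdf(zs))
print(jnp.allclose(scan_lp, loop_log_density(zs, x_matrix, lengths, 0.7, 0.8), rtol=1e-4, atol=1e-4))
```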

Thank you, using pyro.handlers.mask speeds up compilation compared to the version without it.

About the use of handlers.mask(): in this example, what happens to the variable z of short sequences once i > len_seq? I’m wondering which z is used in the sampling (Bernoulli distribution). My concern is whether the model is fitting the parameters correctly.

About the use of Pyro’s MaskedDistribution: how would the mask be used in this example? I get the same shape issue with the equivalent Pyro code (see below).

#Example for pyro.poutine.mask

import torch
import pyro
from pyro import poutine
from pyro.infer import MCMC, NUTS
import pyro.distributions as dist

def model(y, x_matrix, lengths):
    w = pyro.sample('w', dist.Uniform(0., 1.))
    sigma = pyro.sample('sigma', dist.HalfNormal(3.))
    beta = pyro.sample('beta', dist.HalfNormal(3.))
    
    with pyro.plate("data", len(y)):
        z = torch.zeros(len(y)).float()
        for i in range(lengths.max()):
            with poutine.mask(mask=i<lengths):
                # Same shapes with: dist.Normal(z * w + x_matrix[:, i], sigma).mask(i < lengths)
                z = pyro.sample('z_%d' % i, dist.Normal(z * w + x_matrix[:, i], sigma))
        pyro.sample('y', dist.Bernoulli(logits=beta*z), obs=y)
        

#Define the variables
y = torch.tensor([1, 0, 0, 1, 1, 0]).float() #observations
n_obs = len(y)
lengths = torch.tensor([4, 4, 6, 6, 6, 10]) #lengths of the sequences
x_dict = {j: torch.randn(int(size)) for j, size in enumerate(lengths)}

#Transform x_dict into a matrix
x_matrix = torch.zeros((n_obs, int(lengths.max())))
for j in range(n_obs):
    x_matrix[j, :len(x_dict[j])] = x_dict[j]
    
#Run NUTS
kernel = NUTS(model)
mcmc = MCMC(kernel, warmup_steps=15, num_samples=15)
mcmc.run(y=y, x_matrix=x_matrix, lengths=lengths)

mcmc_samples = mcmc.get_samples()
print(mcmc_samples.keys()) #dict_keys(['beta', 'sigma', 'w', 'z_0', 'z_1', 'z_2', 'z_3', 'z_4', 'z_5', 'z_6', 'z_7', 'z_8', 'z_9'])
print(mcmc_samples['z_0'].shape) #(15, 6) as expected
print(mcmc_samples['z_9'].shape) #(15, 6), but I would expect (15,1): there is only one sequence of length 10

About the use of handlers.mask(): in this example, what happens to the variable z of short sequences once i > len_seq? I’m wondering which z is used in the sampling (Bernoulli distribution).

Okay, I see the issue. Masking just ensures that the sequences with i > len_seq don’t contribute to the potential energy but z will always have shape = (6,) in your example. So we need to manually apply the mask when considering samples from HMC.

Now I think, in your example, the last sequence has the max length (is that correct?), and therefore, you expect that the z that we finally get is a scalar. If that’s the case, you can simply pull the last value out using dist.Bernoulli(logits=beta * z[-1]). Does that make sense? I have also updated the numpyro snippet above to take this into account.

@neerajprad No, in the Bernoulli sampling z should be a vector of all the last hidden states (of all sequences). My model is simply an HMM with an observation at the end of each sequence:

for i_seq in all_seq:
    len_seq = lengths[i_seq]
    z[i_seq, 0] = 0
    for i in range(len_seq):
        z[i_seq, i+1] ~ Normal (z[i_seq, i] * w + x_matrix[i_seq, i], sigma)
    y[i_seq] ~ Bernoulli (sigmoid(beta * z[i_seq, len_seq]))

and I’m trying to vectorize this model, by using handlers.mask or something else.

I’m getting confused with what “with handlers.mask(i<lengths):” is doing.
In the end I don’t care about the samples of the variable z, I just want the samples of my parameters (w, sigma and beta) to reflect their posterior distribution with my model and data.
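A vectorized version has to step through time for all sequences at once and update only the sequences that are still running. After the loop, each entry then holds that sequence's *last* valid state, which is what the Bernoulli term needs. Here is a minimal NumPy sketch of that forward pass, assuming the noise is drawn up front as an `(L, N)` array `eps` (a hypothetical name), checked against the one-sequence-at-a-time pseudocode above:

```python
import numpy as np


def final_states_loop(eps, x_matrix, lengths, w, sigma):
    """Reference: one sequence at a time, as in the pseudocode above."""
    out = np.zeros(len(lengths))
    for n, length in enumerate(lengths):
        z = 0.0
        for i in range(length):
            z = z * w + x_matrix[n, i] + sigma * eps[i, n]
        out[n] = z
    return out


def final_states_vectorized(eps, x_matrix, lengths, w, sigma):
    """All sequences per time step; finished sequences keep their state."""
    z = np.zeros(len(lengths))
    for i in range(eps.shape[0]):
        z_new = z * w + x_matrix[:, i] + sigma * eps[i]
        z = np.where(i < lengths, z_new, z)  # carry the last valid state forward
    return z


rng = np.random.default_rng(0)
lengths = np.array([4, 4, 6, 6, 6, 10])
L, N = lengths.max(), len(lengths)
eps = rng.normal(size=(L, N))
x_matrix = rng.normal(size=(N, L))
print(np.allclose(final_states_loop(eps, x_matrix, lengths, 0.7, 0.8),
                  final_states_vectorized(eps, x_matrix, lengths, 0.7, 0.8)))
```

The time loop in the vectorized version has a fixed trip count `L` regardless of the individual lengths, which is what makes it expressible with `lax.scan` in JAX.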

1 Like

I’m not sure this is correct. I don’t know about the zs[-1][-1] vs zs[-1]. I don’t understand all the code, but I ran it on simulated data and I wasn’t able to recover the parameters that generated the data. Are you sure that there should be zs in R = numpyro.sample('R', dist.Bernoulli(logits=beta * zs[-1]), obs=y) rather than z_masked?

Thanks for your code snippet, I understand what the model is doing now. I don’t think you need mask in the way we are using it, but I think you could completely vectorize the computation as follows.

Update: I have made multiple revisions to my earlier code based on @vincentbt’s comments and test cases - it verifies the log joint computation between the sequential and vectorized versions now, and we are able to recover the parameter means successfully. The vectorized version should be much faster to compile.

from jax import lax
import jax.numpy as np
from jax import random
import numpy as onp
from numpy.testing import assert_allclose

import numpyro
import numpyro.distributions as dist
from numpyro import handlers
from numpyro.infer import NUTS, MCMC
from numpyro.infer.util import get_potential_fn


def sequential(y, x_matrix, lengths):
    w = numpyro.sample('w', dist.Uniform(0., 1.))
    sigma = numpyro.sample('sigma', dist.HalfNormal(3.))
    beta = numpyro.sample('beta', dist.HalfNormal(3.))
    L = int(np.max(lengths))

    for j, seq in enumerate(x_matrix):
        len_seq = lengths[j]
        z = 0
        for i in range(len_seq):
            z = numpyro.sample('z_{}{}'.format(j, i), dist.Normal(z * w + seq[i], sigma))
        R = numpyro.sample('R_{}'.format(j), dist.Normal(beta * z, 1.), obs=y[j])


def _body_fn(c, val):
    z_prev, w, sigma, lengths = c
    z, i, x_matrix = val
    n = dist.Normal(z_prev * w + x_matrix, sigma)
    # Subtract log prob contribution from N(0, 1) in main body
    log_prob = -dist.Normal(0., 1.).log_prob(z)
    # This is only needed so that the final z vector returned
    # has the last z value for each sequence
    z = np.where(i < lengths, z, z_prev)
    # Add the log_prob contribution for all sequences for the
    # i'th step (mask sequences which are already terminated).
    log_prob += n.log_prob(z) * (i < lengths)
    return (z, w, sigma, lengths), log_prob


def vectorized(y, x_matrix, lengths):
    assert len(lengths) == len(y)
    w = numpyro.sample('w', dist.Uniform(0., 1.))
    sigma = numpyro.sample('sigma', dist.HalfNormal(3.))
    beta = numpyro.sample('beta', dist.HalfNormal(3.))
    L = int(np.max(lengths))

    with numpyro.plate("y", len(y)):
        with numpyro.plate("len", L):
            # randomly sampling from N(0, 1) - contribution to PE is subtracted in _body_fn
            zs = numpyro.sample('zs', dist.Normal(0., 1.))
            c, vals = lax.scan(_body_fn,
                               (np.zeros(len(y)), w, sigma, lengths),
                               (zs, np.arange(L), x_matrix.T))
            z = c[0]
            log_prob = vals
            numpyro.factor('log_prob', np.sum(log_prob))
            numpyro.deterministic('log_p', log_prob)
            numpyro.deterministic('z', z)
        R = numpyro.sample('R', dist.Normal(beta * z, 1.), obs=y)


def generate_data():
    # Simulate data
    onp.random.seed(1)
    sigma = 0.8
    w = 0.69
    beta = 2.7
    lengths = np.array([4] * 210 + [8] * 220 + [12] * 215)
    n_obs = len(lengths)
    L = lengths.max()
    x_matrix = onp.zeros((n_obs, L))

    for i in range(n_obs):
        x_matrix[i, :lengths[i]] = onp.random.normal(0., 1., size=(lengths[i],))

    model_ = handlers.condition(handlers.seed(sequential, 1), {'w': w, 'sigma': sigma, 'beta': beta})
    trace = handlers.trace(model_).get_trace([None] * n_obs, x_matrix, lengths)
    y = np.stack([trace['R_{}'.format(i)]['value'] for i in range(n_obs)])
    z = np.reshape(np.stack([trace['z_{}{}'.format(i, j)]['value'] if j < lengths[i] else 0
                             for i in range(n_obs)
                             for j in range(L)]), (n_obs, L))
    # Return values for both inputs / observed, as well as the latents
    # (for checking potential energy computation)
    return {'sigma': sigma,
            'w': w,
            'beta': beta,
            'lengths': lengths,
            'x_matrix': x_matrix,
            'z': z,
            'y': y,
            'trace': trace}


def assert_seq_vec_potential_energy_match(data):
    # sequential model: check PE (-log_joint) for generated data point.
    seq_sample = {k: v['value'] for k, v in data['trace'].items() if v['type'] == 'sample'}
    seq_pe = get_potential_fn(random.PRNGKey(2), sequential,
                              model_args=(data['y'], data['x_matrix'], data['lengths']),
                              model_kwargs={})[0](seq_sample)
    # vectorized model: check PE (-log_joint) for generated data point.
    vec_pe = get_potential_fn(random.PRNGKey(2), vectorized,
                              model_args=(data['y'], data['x_matrix'], data['lengths']),
                              model_kwargs={})[0](dict(w=data['w'], sigma=data['sigma'], beta=data['beta'],
                                                       zs=data['z'].T))
    assert_allclose(seq_pe, vec_pe, rtol=1e-3)
    print('success: PE is same for vectorized and sequential models.')


def run_inference(model, data):
    rng_key = random.PRNGKey(0)
    kernel = NUTS(model)
    mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000)
    mcmc.run(rng_key, y=data['y'], x_matrix=data['x_matrix'], lengths=data['lengths'])
    mcmc_samples = mcmc.get_samples()
    print(np.mean(mcmc_samples['sigma'], axis=0))
    print(np.mean(mcmc_samples['w'], axis=0))
    print(np.mean(mcmc_samples['beta'], axis=0))


print('generating data...')
data = generate_data()
print('run inference using vectorized model...')
assert_seq_vec_potential_energy_match(data)
run_inference(vectorized, data)
print('run inference using sequential model...')
# WARNING: takes too long to compile
run_inference(sequential, data)

Does that roughly address your question? Happy to chat offline in more detail.

@neerajprad I’m just curious (I haven’t looked into the problem in more detail). I understand that we should use non-centering reparameterization together with lax.scan here for efficiency. But would it be simpler to only use mask on zs and use lax.scan to compute z?

it verifies the log joint computation between the sequential and vectorized versions

I think with non-centering reparameterization, we want to compute p(zs, …), so the difference between two versions should be fine. WDYT?
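Per step, the two parameterizations differ only by the change-of-variables term. If z = mu + sigma * eps with eps ~ N(0, 1), then log N(z | mu, sigma) = log N(eps | 0, 1) - log sigma. Here is a quick numerical check in NumPy with arbitrary values; `normal_logpdf` is a small hand-written helper, not a library function:

```python
import numpy as np


def normal_logpdf(x, loc=0.0, scale=1.0):
    """Log density of Normal(loc, scale) evaluated at x."""
    return -0.5 * ((x - loc) / scale) ** 2 - np.log(scale) - 0.5 * np.log(2 * np.pi)


rng = np.random.default_rng(0)
mu = rng.normal(size=5)   # e.g. z_prev * w + x for five sequences
sigma = 0.8
eps = rng.normal(size=5)  # non-centered latent
z = mu + sigma * eps      # centered latent

centered = normal_logpdf(z, mu, sigma)
non_centered = normal_logpdf(eps)
print(np.allclose(centered, non_centered - np.log(sigma)))
```

Summed over the valid steps, the gap is (number of valid steps) * log sigma. Because this depends on sigma, the two potential energies are not equal, even though each is the correct log joint in its own coordinates. That is why the PE assertion against the centered sequential model would not carry over unchanged.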

1 Like

I understand that we should use non-centering reparameterization together with lax.scan here for efficiency.

Great point, yes, I think using a non-centered parametrization in _body_fn will be much better! My goal here was only to verify that the model could be rewritten using jax primitives to enable fast compilation.

But would it be simpler to only use mask on zs and use lax.scan to compute z?

I’m not sure if that will be much simpler, though I might be misunderstanding your suggestion. The mask operation is just a small detail in _body_fn.

I think with non-centering reparameterization, we want to compute p(zs, …), so the difference between two versions should be fine.

Yes, in that case the log density computation would differ (unless we change the sequential version accordingly) but that should be fine. We only have this assertion to convince ourselves that the vectorized version is indeed doing the right thing.

Understood! I thought you were using reparameterization, but after looking into the solution again, I understand that you just want to construct a corresponding one for the sequential model.

For users who are interested in the non-centered reparameterization I mentioned above, here is the corresponding model (a slight modification of @neerajprad’s answer). This parameterization should mix much better than the centered one. Note that future versions of JAX will include a grad rule for while loops, which will allow us to use lax.fori_loop instead of lax.scan here (not important, though).

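import jax.numpy as np
from jax import lax

import numpyro
import numpyro.distributions as dist
from numpyro import handlers


def scan_z(noises, x_matrix, w, lengths):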
    def _body_fn(z_prev, val):
        noise, i, x_col = val
        z = z_prev * w + x_col + noise
        z = np.where(i < lengths, z, z_prev)
        return z, None

    return lax.scan(_body_fn, np.zeros(noises.shape[1]),
                    (noises, np.arange(noises.shape[0]), x_matrix.T))[0]


def reparam_model(y, x_matrix, lengths):
    w = numpyro.sample('w', dist.Uniform(0., 1.))
    sigma = numpyro.sample('sigma', dist.HalfNormal(3.))
    beta = numpyro.sample('beta', dist.HalfNormal(3.))
    L = int(np.max(lengths))
    with numpyro.plate("y", len(y)):
        with numpyro.plate("len", L):
            with handlers.mask(np.arange(L)[..., None] < lengths):
                noises = sigma * numpyro.sample('noises', dist.Normal(0., 1.))
            z = scan_z(noises, x_matrix, w, lengths)
            numpyro.deterministic('z', z)
        numpyro.sample('R', dist.Normal(beta * z, 1.), obs=y)
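On the `lax.fori_loop` remark, the `lax.fori_loop` documentation says that when the loop bounds are known at trace time (for example, Python ints), the loop is implemented via `scan` and supports reverse-mode differentiation. Under that assumption, the forward pass of `scan_z` could also be written with `fori_loop`. The sketch below is plain JAX with no numpyro. The names `scan_z_scan` and `scan_z_fori` are hypothetical. It checks the two versions against each other and takes a gradient with respect to `w`.

```python
import jax
import jax.numpy as jnp
from jax import lax


def scan_z_scan(noises, x_matrix, w, lengths):
    """Same recurrence as scan_z above, written with lax.scan."""
    def body(z_prev, val):
        noise, i, x_col = val
        z = jnp.where(i < lengths, z_prev * w + x_col + noise, z_prev)
        return z, None
    L, N = noises.shape
    return lax.scan(body, jnp.zeros(N), (noises, jnp.arange(L), x_matrix.T))[0]


def scan_z_fori(noises, x_matrix, w, lengths):
    """Same recurrence with lax.fori_loop; L is a Python int (static trip count)."""
    L, N = noises.shape

    def body(i, z_prev):
        z = z_prev * w + x_matrix[:, i] + noises[i]
        return jnp.where(i < lengths, z, z_prev)
    return lax.fori_loop(0, L, body, jnp.zeros(N))


key_n, key_x = jax.random.split(jax.random.PRNGKey(0))
lengths = jnp.array([4, 4, 6, 6, 6, 10])
L, N = int(lengths.max()), len(lengths)
noises = 0.8 * jax.random.normal(key_n, (L, N))
x_matrix = jax.random.normal(key_x, (N, L))

print(jnp.allclose(scan_z_scan(noises, x_matrix, 0.7, lengths),
                   scan_z_fori(noises, x_matrix, 0.7, lengths), atol=1e-5))
# Reverse-mode gradient through the fori_loop version
grad_w = jax.grad(lambda w: scan_z_fori(noises, x_matrix, w, lengths).sum())(0.7)
print(grad_w)
```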
1 Like

Whoa…that’s much simpler, thanks for adding this non-centered version! Now I see what you meant by masking outside the scan body. @vincentbt - you probably want to use something like @fehiepsi’s version above for your models, which will mix much better.

1 Like

Thanks to both of you for your help!
The models you propose compile fast and parameter recovery works. And the reparam_model is really simple to understand!