HMM-like model with sequences of different lengths

I’m trying to use numpyro’s mask with the following model:

z = 0
for i in range(len_seq):
    z ~ Normal(z * w + x[i], sigma)
y ~ Bernoulli(sigmoid(beta * z))

where the sequences don’t have the same length.

I get an unexpected result on the shape of the samples (see below).

#Example for numpyro.handlers.mask

import numpyro
from numpyro import handlers
import numpy
import jax
import jax.numpy as np
from numpyro.infer import MCMC, NUTS
import numpyro.distributions as dist

def model(y, x_matrix, lengths):
    w = numpyro.sample('w', dist.Uniform(0., 1.))
    sigma = numpyro.sample('sigma', dist.HalfNormal(3.))
    beta = numpyro.sample('beta', dist.HalfNormal(3.))
    
    with numpyro.plate("data", len(y)):
        z = np.zeros(len(y))
        for i in range(lengths.max()):
            with handlers.mask((i<lengths)):
                z = numpyro.sample('z_%d'%i, dist.Normal(z * w + x_matrix[:,i], sigma))
        numpyro.sample('y', dist.Bernoulli(logits=beta*z), obs=y)
        

#Define the variables
y = np.array([1, 0, 0, 1, 1, 0]) #observations
n_obs = len(y)
lengths = np.array([4, 4, 6, 6, 6, 10]) #lengths of the sequences
x_dict = {j : numpy.random.normal(loc=0, scale=1, size=size) for j, size in enumerate(lengths)}

#Transform x_dict into a matrix
x_matrix = numpy.zeros((n_obs, lengths.max()))
for j in range(n_obs):
    x_matrix[j, :len(x_dict[j])] = x_dict[j]
x_matrix = np.array(x_matrix)
    
#Run NUTS
rng_key = jax.random.PRNGKey(0)
kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=100, num_samples=200)
mcmc.run(rng_key, y=y, x_matrix=x_matrix, lengths=lengths)

mcmc_samples = mcmc.get_samples()
print(mcmc_samples.keys()) #dict_keys(['beta', 'sigma', 'w', 'z_0', 'z_1', 'z_2', 'z_3', 'z_4', 'z_5', 'z_6', 'z_7', 'z_8', 'z_9'])
print(mcmc_samples['z_0'].shape) #(200, 6) as expected
print(mcmc_samples['z_9'].shape) #(200, 6), but I would expect (200,1): there is only one sequence of length 10

That makes me think I didn’t use the mask properly. Any hints?

Note that this code uses numpyro.handlers.mask, which is not on PyPI yet but corresponds to the current master on GitHub. That version can be installed with: pip install git+https://github.com/pyro-ppl/numpyro@bb90d6ef657f19842f54baff37f3d20c7ca6f10a

Thanks for the clear example. When you use handlers.mask(), the mask is only applied during the potential energy (log density) computation, to ignore some of the entries in z; you will still see the full matrix captured in the trace and will need to mask it out manually. To get masked values recorded in the trace, we will additionally need support for Pyro’s MaskedDistribution. I’ll create a task for this.
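In the meantime, here is a rough post-processing sketch (mine, not part of your model) for masking the recorded values manually, reusing the lengths array from your snippet:

# Post-hoc masking sketch (mine): build a (step, sequence) validity mask from
# `lengths` and use it to drop the meaningless entries of a given z_i site.
import numpy
valid = numpy.arange(int(lengths.max()))[:, None] < numpy.asarray(lengths)  # shape (10, 6)
# e.g. keep only the sequences that are still running at step 9:
z9_valid = numpy.asarray(mcmc_samples['z_9'])[:, valid[9]]  # shape (num_samples, 1) here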

My earlier suggestion of using lax.scan in the inner loop is unlikely to work as is, since we are sampling (and recording values) inside the loop. Is the compilation time reasonable with this version?

Update: I made some changes to your model that may result in faster compilation, depending on how big the inner loop is (if it is only 10-12 steps, as in the example, it may not have much impact). There may still be bugs in the code, so please double-check the results of inference. :slight_smile:

  • The first change is just hand-writing the log density contribution of the inner loop using lax.scan. NumPyro primitives and effect handlers won’t work inside JAX control flow primitives, so this is a way around that. Note that we are nullifying the contribution of N(0, 1) in _body_fn. If the inner loop is long (max(lengths) is high), using lax.scan will result in faster compilation.
  • I am recording the masked out values zs_masked so that we can later look at these using Predictive. This will be much nicer once we have numpyro.deterministic available.
import jax
import jax.numpy as np
import numpy
from jax import lax
from jax import random

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive


def _body_fn(c, val):
    log_prob, w, sigma, lengths = c
    z, i, x_matrix = val
    log_prob = log_prob + np.sum(dist.Normal(z * w + x_matrix, sigma).log_prob(z) * (i < lengths))
    # subtract log_prob contribution from N(0, 1) in the model
    log_prob = log_prob - np.sum(dist.Normal(0., 1.).log_prob(z))
    return (log_prob, w, sigma, lengths), z * (i < lengths)


def model(y, x_matrix, lengths):
    w = numpyro.sample('w', dist.Uniform(0., 1.))
    sigma = numpyro.sample('sigma', dist.HalfNormal(3.))
    beta = numpyro.sample('beta', dist.HalfNormal(3.))
    L = int(np.max(lengths))

    with numpyro.plate("y", len(y)):
        with numpyro.plate("len", L):
            # randomly sampling from N(0, 1) - contribution to PE is subtracted in _body_fn
            zs = numpyro.sample('zs', dist.Normal(0., 1.))
            c, z_masked = lax.scan(_body_fn, (0., w, sigma, lengths), (zs, np.arange(L), x_matrix.T))
            log_prob = c[0]
            numpyro.factor('log_prob', log_prob)
            # No-op: only meant to record these values
            numpyro.sample('zs_masked', dist.Delta(z_masked), obs=z_masked)
        R = numpyro.sample('R', dist.Bernoulli(logits=beta * zs[-1][-1]), obs=y)


# Define the variables
y = np.array([1, 0, 0, 1, 1, 0])  # observations
n_obs = len(y)
lengths = np.array([4, 4, 6, 6, 6, 10])  # lengths of the sequences
x_dict = {j: numpy.random.normal(loc=0, scale=1, size=size) for j, size in enumerate(lengths)}

# Transform x_dict into a matrix
x_matrix = numpy.zeros((n_obs, lengths.max()))
for j in range(n_obs):
    x_matrix[j, :len(x_dict[j])] = x_dict[j]
x_matrix = np.array(x_matrix)

# Run NUTS
rng_key = jax.random.PRNGKey(0)
kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=400, num_samples=500)
mcmc.run(rng_key, y=y, x_matrix=x_matrix, lengths=lengths)

mcmc_samples = mcmc.get_samples()
# Use Predictive to get values for zs_masked
samples = Predictive(model, mcmc_samples)(random.PRNGKey(1), y=y, x_matrix=x_matrix, lengths=lengths)
# Look at a random sample
print(samples['zs_masked'][100])

Let me know if this solves your issue of slow compilation time.

Thank you, using numpyro.handlers.mask speeds up the compilation compared to not using it.

About the use of handlers.mask(): in this example, what happens to the variable z of the short sequences once i > len_seq? I’m wondering which z ends up being used in the Bernoulli sampling. My concern is whether the model is fitting the parameters correctly.

About the use of Pyro’s MaskedDistribution: how would the mask be used in this example? I get the same shape issue with the equivalent Pyro code (see below).

#Example for pyro.poutine.mask

import torch

import pyro
from pyro import poutine
from pyro.infer import MCMC, NUTS
import pyro.distributions as dist

def model(y, x_matrix, lengths):
    w = pyro.sample('w', dist.Uniform(0., 1.))
    sigma = pyro.sample('sigma', dist.HalfNormal(3.))
    beta = pyro.sample('beta', dist.HalfNormal(3.))
    
    with pyro.plate("data", len(y)):
        z = torch.zeros(len(y)).float()
        for i in range(lengths.max()):
            with poutine.mask(mask=i<lengths):
                z = pyro.sample('z_%d'%i, dist.Normal(z * w + x_matrix[:,i], sigma)) #does not work either with: z = pyro.sample('z_%d'%i, dist.Normal(z * w + x_matrix[:,i], sigma).mask(i<lengths))
        pyro.sample('y', dist.Bernoulli(logits=beta*z), obs=y)
        

#Define the variables
y = torch.tensor([1, 0, 0, 1, 1, 0]).float() #observations
n_obs = len(y)
lengths = torch.tensor([4, 4, 6, 6, 6, 10]) #lengths of the sequences
x_dict = {j : torch.randn(size) for j, size in enumerate(lengths)}

#Transform x_dict into a matrix
x_matrix = torch.zeros((n_obs, lengths.max()))
for j in range(n_obs):
    x_matrix[j, :len(x_dict[j])] = x_dict[j]
    
#Run NUTS
kernel = NUTS(model)
mcmc = MCMC(kernel, warmup_steps=15, num_samples=15)
mcmc.run(y=y, x_matrix=x_matrix, lengths=lengths)

mcmc_samples = mcmc.get_samples()
print(mcmc_samples.keys()) #dict_keys(['beta', 'sigma', 'w', 'z_0', 'z_1', 'z_2', 'z_3', 'z_4', 'z_5', 'z_6', 'z_7', 'z_8', 'z_9'])
print(mcmc_samples['z_0'].shape) #(15, 6) as expected
print(mcmc_samples['z_9'].shape) #(15, 6), but I would expect (15,1): there is only one sequence of length 10

About the use of handlers.mask(): in this example, what happens to the variable z of the short sequences once i > len_seq? I’m wondering which z ends up being used in the Bernoulli sampling.

Okay, I see the issue. Masking just ensures that the steps beyond a sequence’s length don’t contribute to the potential energy, but z will always have shape (6,) in your example. So we need to manually apply the mask when looking at samples from HMC.
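For example, a rough post-processing sketch (mine, not from your code) that stacks the per-step samples and blanks out the entries where the mask was False:

# Stack z_0..z_{L-1} into (num_samples, L, n_obs) and NaN out the invalid steps.
import numpy
L = int(lengths.max())
zs = numpy.stack([numpy.asarray(mcmc_samples['z_%d' % i]) for i in range(L)], axis=1)
valid = numpy.arange(L)[:, None] < numpy.asarray(lengths)   # (L, n_obs)
zs_masked = numpy.where(valid, zs, numpy.nan)                # invalid steps become NaN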

Now, I think that in your example the last sequence has the max length (is that correct?), and therefore you expect the z we finally get to be a scalar. If that’s the case, you can simply pull the last value out using dist.Bernoulli(logits=beta * z[-1]). Does that make sense? I have also updated the numpyro snippet above to take this into account.

@neerajprad No, in the Bernoulli sampling z should be a vector of all the last hidden states (one per sequence). My model is simply an HMM with an observation at the end of each sequence:

for i_seq in all_seq:
    len_seq = lengths[i_seq]
    z[i_seq, 0] = 0
    for i in range(len_seq):
        z[i_seq, i+1] ~ Normal(z[i_seq, i] * w + x_matrix[i_seq, i], sigma)
    y[i_seq] ~ Bernoulli(sigmoid(beta * z[i_seq, len_seq]))

and I’m trying to vectorize this model, by using handlers.mask or something else.

I’m getting confused about what “with handlers.mask(i<lengths):” is actually doing.
In the end I don’t care about the samples of the variable z; I just want the samples of my parameters (w, sigma and beta) to reflect their posterior distribution given my model and data.
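For reference, this is roughly how I simulate data from this model in plain numpy when checking whether the fit recovers the parameters (the parameter values here are just placeholders):

# Simulation sketch for the generative model above (placeholder parameters).
import numpy

def simulate(lengths, w=0.5, sigma=1.0, beta=2.0, seed=0):
    rng = numpy.random.default_rng(seed)
    x_dict, y = {}, []
    for j, len_seq in enumerate(lengths):
        x = rng.normal(size=len_seq)
        z = 0.0
        for i in range(len_seq):
            z = rng.normal(z * w + x[i], sigma)      # hidden-state transition
        p = 1.0 / (1.0 + numpy.exp(-beta * z))       # sigmoid of the last hidden state
        y.append(rng.binomial(1, p))
        x_dict[j] = x
    return x_dict, numpy.array(y)

# e.g. x_dict, y = simulate([4, 4, 6, 6, 6, 10])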

I’m not sure this is correct, and I’m unsure about zs[-1][-1] vs zs[-1]. I don’t understand all the code, but I ran it on simulated data and wasn’t able to recover the parameters that generated the data. Are you sure there should be zs in R = numpyro.sample('R', dist.Bernoulli(logits=beta * zs[-1]), obs=y) rather than z_masked?

Thanks for your code snippet, I understand what the model is doing now. I don’t think you need mask in the way we were using it; instead, you can completely vectorize the computation as follows. This is just to convey the idea; I may have messed up the log density computation, so please double-check:

Update: I have edited the code below. There were serious bugs in the previous version: we were reusing the same source of randomness, and the z’s were not being sampled by HMC. This also uses numpyro.deterministic, for which you’ll need NumPyro’s master branch.

import jax
import jax.numpy as np
import numpy
from jax import lax

import numpyro
import numpyro.distributions as dist
from numpyro import handlers
from numpyro.infer import MCMC, NUTS


def _body_fn(c, val):
    z_prev, w, sigma, lengths = c
    z, i, x_matrix = val
    n = dist.Normal(z_prev * w + x_matrix, sigma)
    z = np.where(i < lengths, z, z_prev)
    log_prob = (n.log_prob(z) - dist.Normal(0., 1.).log_prob(z)) * (i < lengths)
    return (z, w, sigma, lengths), log_prob


def model(y, x_matrix, lengths):
    w = numpyro.sample('w', dist.Uniform(0., 1.))
    sigma = numpyro.sample('sigma', dist.HalfNormal(3.))
    beta = numpyro.sample('beta', dist.HalfNormal(3.))
    L = int(np.max(lengths))

    with numpyro.plate("y", len(y)):
        with numpyro.plate("len", L):
            # randomly sampling from N(0, 1) - contribution to PE is subtracted in _body_fn
            zs = numpyro.sample('zs', dist.Normal(0., 1.))
            c, vals = lax.scan(_body_fn,
                               (np.zeros(len(y)), w, sigma, lengths),
                               (zs, np.arange(L), x_matrix.T))
            z = c[0]
            log_prob = vals
            numpyro.factor('log_prob', np.sum(log_prob))
            numpyro.deterministic('log_p', log_prob)
            numpyro.deterministic('z', z)
        R = numpyro.sample('R', dist.Bernoulli(logits=beta * z), obs=y)


# Define the variables
y = np.array([1, 0, 0, 1, 1, 0])  # observations
n_obs = len(y)
lengths = np.array([4, 4, 6, 6, 6, 10])  # lengths of the sequences
x_dict = {j: numpy.random.normal(loc=0, scale=1, size=size) for j, size in enumerate(lengths)}

# Transform x_dict into a matrix
x_matrix = numpy.zeros((n_obs, lengths.max()))
for j in range(n_obs):
    x_matrix[j, :len(x_dict[j])] = x_dict[j]
x_matrix = np.array(x_matrix)

# Run NUTS
rng_key = jax.random.PRNGKey(0)
kernel = NUTS(handlers.block(handlers.seed(model, rng_key),
                             hide_fn=lambda site: site['name'] == 'rng_key'))
mcmc = MCMC(kernel, num_warmup=400, num_samples=600)
mcmc.run(rng_key, y=y, x_matrix=x_matrix, lengths=lengths)

mcmc_samples = mcmc.get_samples()
print(np.mean(mcmc_samples['sigma'], axis=0))
print(np.mean(mcmc_samples['w'], axis=0))
print(np.mean(mcmc_samples['beta'], axis=0))
print(mcmc_samples['log_p'][10])
print(mcmc_samples['z'][10])
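
As a sanity check on parameter recovery, here is a rough sketch (mine; the true parameter values, the simulated lengths, and the names lengths_sim, x_sim, y_sim, mcmc_check are made up) that simulates data with known parameters and refits the model above:

# Recovery-check sketch: simulate with known (w, sigma, beta), refit, compare.
true_w, true_sigma, true_beta = 0.7, 0.5, 2.0
rng = numpy.random.default_rng(0)
lengths_sim = numpy.repeat([5, 8, 12], 40)                  # 120 simulated sequences
x_sim = numpy.zeros((len(lengths_sim), lengths_sim.max()))
y_sim = numpy.zeros(len(lengths_sim))
for j, Lj in enumerate(lengths_sim):
    x_sim[j, :Lj] = rng.normal(size=Lj)
    z_j = 0.0
    for i in range(Lj):
        z_j = rng.normal(true_w * z_j + x_sim[j, i], true_sigma)
    y_sim[j] = rng.binomial(1, 1.0 / (1.0 + numpy.exp(-true_beta * z_j)))

kernel_check = NUTS(handlers.block(handlers.seed(model, rng_key),
                                   hide_fn=lambda site: site['name'] == 'rng_key'))
mcmc_check = MCMC(kernel_check, num_warmup=400, num_samples=600)
mcmc_check.run(rng_key, y=np.array(y_sim), x_matrix=np.array(x_sim), lengths=np.array(lengths_sim))
for name, truth in [('w', true_w), ('sigma', true_sigma), ('beta', true_beta)]:
    print(name, np.mean(mcmc_check.get_samples()[name]), 'vs true', truth)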

Does that roughly address your question? Happy to chat offline in more detail.