HMM-like model with sequences of different lengths

Thanks for your code snippet — I understand what the model is doing now. I don’t think you need the mask the way it is currently used; instead, I think you could vectorize the computation completely, as follows.

Update: I have revised my earlier code several times based on @vincentbt’s comments and test cases. It now verifies that the sequential and vectorized versions compute the same log joint, and we are able to recover the parameter means successfully. The vectorized version should also be much faster to compile.

from jax import lax
import jax.numpy as np
from jax import random
import numpy as onp
from numpy.testing import assert_allclose

import numpyro
import numpyro.distributions as dist
from numpyro import handlers
from numpyro.infer import NUTS, MCMC
from numpyro.infer.util import get_potential_fn


def sequential(y, x_matrix, lengths):
    """Loop-based reference model with one scalar latent chain per sequence.

    For sequence ``j`` with ``lengths[j]`` steps, the model samples latent
    sites ``z_{j}{i}`` via ``z_i ~ Normal(z_{i-1} * w + x_matrix[j][i], sigma)``
    starting from ``z = 0``. It then observes ``R_{j} ~ Normal(beta * z_last, 1)``
    with value ``y[j]``.

    Args:
        y: per-sequence observations, indexed as ``y[j]``; entries may be
            ``None`` to sample ``R_{j}`` instead of conditioning on it.
        x_matrix: iterated row by row; row ``j`` must have at least
            ``lengths[j]`` entries.
        lengths: number of steps in each sequence, indexed as ``lengths[j]``.
    """
    w = numpyro.sample('w', dist.Uniform(0., 1.))
    sigma = numpyro.sample('sigma', dist.HalfNormal(3.))
    beta = numpyro.sample('beta', dist.HalfNormal(3.))

    for j, seq in enumerate(x_matrix):
        len_seq = lengths[j]
        z = 0
        for i in range(len_seq):
            # NOTE: j and i are concatenated without a separator, so names are
            # only unique while no two (j, i) pairs render to the same string
            # (e.g. j=1, i=10 and j=11, i=0 both give 'z_110'). generate_data
            # looks these exact names up, so change both together if needed.
            z = numpyro.sample('z_{}{}'.format(j, i), dist.Normal(z * w + seq[i], sigma))
        numpyro.sample('R_{}'.format(j), dist.Normal(beta * z, 1.), obs=y[j])


def _body_fn(c, val):
    """Perform one ``lax.scan`` step of the vectorized chain (time step ``i``).

    Args:
        c: carry ``(z_prev, w, sigma, lengths)``; only ``z_prev`` changes
            between steps.
        val: per-step slice ``(z, i, x_matrix)`` holding this step's latent
            values, the step index, and this step's covariates.

    Returns:
        The updated carry and this step's log-density contribution. The
        contribution is the masked transition log prob minus the N(0, 1)
        placeholder log prob that ``vectorized`` assigns to every ``zs`` entry.
    """
    z_prev, w, sigma, lengths = c
    z_raw, step, x_col = val
    active = step < lengths

    # Cancels the N(0, 1) placeholder prior on every entry, padded or not.
    placeholder_lp = dist.Normal(0., 1.).log_prob(z_raw)

    # Once a sequence has ended, keep carrying its last z forward, so the final
    # carry holds each sequence's terminal z.
    z_cur = np.where(active, z_raw, z_prev)

    # Transition term, zeroed for sequences that have already terminated.
    # NOTE(review): for padded steps the net contribution is therefore 0, so
    # those zs are unconstrained by the joint (a flat direction for NUTS).
    transition = dist.Normal(z_prev * w + x_col, sigma)
    step_lp = transition.log_prob(z_cur) * active - placeholder_lp
    return (z_cur, w, sigma, lengths), step_lp


def vectorized(y, x_matrix, lengths):
    """Vectorized counterpart of ``sequential`` built on ``lax.scan``.

    All latent steps are sampled at once as ``zs`` under placeholder N(0, 1)
    priors. ``_body_fn`` then cancels the placeholder and adds the actual
    (length-masked) transition log prob, and the total enters via ``factor``.

    Args:
        y: observed response per sequence; ``len(y)`` must equal
            ``len(lengths)``.
        x_matrix: covariates, transposed and scanned over ``max(lengths)``
            steps — expected shape (len(y), max(lengths)), padded beyond each
            sequence's length.
        lengths: number of valid steps for each sequence.
    """
    assert len(lengths) == len(y)
    w = numpyro.sample('w', dist.Uniform(0., 1.))
    sigma = numpyro.sample('sigma', dist.HalfNormal(3.))
    beta = numpyro.sample('beta', dist.HalfNormal(3.))
    L = int(np.max(lengths))

    with numpyro.plate("y", len(y)):
        with numpyro.plate("len", L):
            # Placeholder N(0, 1) draws with shape (L, len(y)). Their prior
            # contribution to the potential energy is subtracted in _body_fn.
            zs = numpyro.sample('zs', dist.Normal(0., 1.))

    carry, log_prob = lax.scan(_body_fn,
                               (np.zeros(len(y)), w, sigma, lengths),
                               (zs, np.arange(L), x_matrix.T))
    z = carry[0]  # last valid z of each sequence

    # Kept outside every plate: a site inside plates gets its distribution
    # expanded to the plate shape, which would count this scalar once per
    # (step, sequence) element instead of once.
    numpyro.factor('log_prob', np.sum(log_prob))
    numpyro.deterministic('log_p', log_prob)
    numpyro.deterministic('z', z)

    with numpyro.plate("y", len(y)):
        numpyro.sample('R', dist.Normal(beta * z, 1.), obs=y)


def generate_data():
    """Simulate a dataset from ``sequential`` with fixed ground-truth parameters.

    There are 645 sequences: 210 of length 4, 220 of length 8 and 215 of
    length 12. Covariates are standard normal and zero-padded to the maximum
    length.

    Returns:
        A dict with the true ``sigma``, ``w`` and ``beta``; the model inputs
        ``lengths`` and ``x_matrix``; the observations ``y``; the latent ``z``
        values laid out as (n_obs, max_length) with zeros in padded
        positions; and the raw ``trace``. The latents and the trace are used
        to check the potential-energy computation.
    """
    # A local RandomState seeded with 1 yields exactly the draws the global
    # seed(1) would, without mutating the global NumPy random state.
    rng = onp.random.RandomState(1)
    sigma = 0.8
    w = 0.69
    beta = 2.7
    lengths = np.array([4] * 210 + [8] * 220 + [12] * 215)
    lengths_host = onp.asarray(lengths)  # plain ints for slicing/comparisons
    n_obs = len(lengths)
    L = int(lengths_host.max())
    x_matrix = onp.zeros((n_obs, L))

    for i, n in enumerate(lengths_host):
        x_matrix[i, :n] = rng.normal(0., 1., size=(n,))

    model_ = handlers.condition(handlers.seed(sequential, 1), {'w': w, 'sigma': sigma, 'beta': beta})
    # y=None for every sequence, so the R_{i} sites are sampled, not observed.
    trace = handlers.trace(model_).get_trace([None] * n_obs, x_matrix, lengths)
    y = np.stack([trace['R_{}'.format(i)]['value'] for i in range(n_obs)])
    z = np.reshape(np.stack([trace['z_{}{}'.format(i, j)]['value'] if j < lengths_host[i] else 0
                             for i in range(n_obs)
                             for j in range(L)]), (n_obs, L))
    return {'sigma': sigma,
            'w': w,
            'beta': beta,
            'lengths': lengths,
            'x_matrix': x_matrix,
            'z': z,
            'y': y,
            'trace': trace}


def assert_seq_vec_potential_energy_match(data):
    """Check that both models give the same potential energy for simulated data.

    ``sequential`` is evaluated at every sample site recorded in
    ``data['trace']``. ``vectorized`` is evaluated at the true ``w``, ``sigma``
    and ``beta`` plus the transposed latent matrix as ``zs``. The two
    potential energies (negative log joints) must agree to ``rtol=1e-3``.

    NOTE(review): the values are passed straight to the potential function.
    That function presumably expects unconstrained parameters. The comparison
    is still consistent because both models share the same priors on w, sigma
    and beta.
    """
    shared_args = (data['y'], data['x_matrix'], data['lengths'])

    def potential_energy(model, params):
        pe_fn = get_potential_fn(random.PRNGKey(2), model,
                                 model_args=shared_args,
                                 model_kwargs={})[0]
        return pe_fn(params)

    seq_params = {name: site['value']
                  for name, site in data['trace'].items()
                  if site['type'] == 'sample'}
    seq_pe = potential_energy(sequential, seq_params)

    vec_params = dict(w=data['w'], sigma=data['sigma'], beta=data['beta'],
                      zs=data['z'].T)
    vec_pe = potential_energy(vectorized, vec_params)

    assert_allclose(seq_pe, vec_pe, rtol=1e-3)
    print('success: PE is same for vectorized and sequential models.')


def run_inference(model, data, *, num_warmup=1000, num_samples=2000, seed=0):
    """Fit ``model`` to ``data`` with NUTS and print posterior means.

    Args:
        model: a model taking ``y``, ``x_matrix`` and ``lengths`` keyword args.
        data: dict with at least ``'y'``, ``'x_matrix'`` and ``'lengths'``.
        num_warmup: number of MCMC warmup iterations.
        num_samples: number of post-warmup samples.
        seed: integer used to build the ``PRNGKey`` for the run.

    Returns:
        The dict of posterior samples from ``mcmc.get_samples()``.

    Prints the posterior means of ``sigma``, ``w`` and ``beta``, in that order.
    """
    rng_key = random.PRNGKey(seed)
    kernel = NUTS(model)
    mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)
    mcmc.run(rng_key, y=data['y'], x_matrix=data['x_matrix'], lengths=data['lengths'])
    mcmc_samples = mcmc.get_samples()
    for name in ('sigma', 'w', 'beta'):
        print(np.mean(mcmc_samples[name], axis=0))
    return mcmc_samples


print('generating data...')
data = generate_data()
# Sanity check before any sampling: both models must assign the same
# potential energy to the simulated ground truth.
print('checking potential energy of sequential vs vectorized models...')
assert_seq_vec_potential_energy_match(data)
print('run inference using vectorized model...')
run_inference(vectorized, data)
print('run inference using sequential model...')
# WARNING: takes too long to compile
run_inference(sequential, data)

Does that roughly address your question? Happy to chat offline in more detail.