HMM-like model with sequences of different lengths

Understood! I thought you were using reparameterization, but after looking at the solution again, I see that you just want to construct the corresponding version of the sequential model.

For users who are interested in the non-centered reparameterization I mentioned above, here is the corresponding model (a slight modification of @neerajprad’s answer). This parameterization should mix much better than the centered one. Note that future versions of JAX will include a gradient rule for while loops, which will allow us to use lax.fori_loop instead of lax.scan here (not important, though).
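In equations, and assuming the centered version samples each latent state directly, the two forms of the recursion differ only in what the sampler explores. For sequence $n$ with length $L_n$, steps $t = 0, \dots, L_n - 1$, and $z_{n,-1} = 0$:

$$
\begin{aligned}
\text{centered:}\quad & z_{n,t} \sim \mathcal{N}\left(w\, z_{n,t-1} + x_{n,t},\ \sigma\right) \\
\text{non-centered:}\quad & \epsilon_{n,t} \sim \mathcal{N}(0, 1), \qquad z_{n,t} = w\, z_{n,t-1} + x_{n,t} + \sigma\, \epsilon_{n,t} \\
\text{likelihood:}\quad & y_n \sim \mathcal{N}\left(\beta\, z_{n,L_n-1},\ 1\right)
\end{aligned}
$$

In the non-centered form the latent sites are independent standard normals and $z$ is a deterministic function of them, which usually weakens the funnel-shaped dependence between the latent states and $\sigma$.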

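```python
import jax.numpy as np
from jax import lax

import numpyro
import numpyro.distributions as dist
from numpyro import handlers


def scan_z(noises, x_matrix, w, lengths):
    # noises: (L, N) scaled noise, x_matrix: (N, L) covariates, lengths: (N,)
    def _body_fn(z_prev, val):
        noise, i, x_col = val
        z = z_prev * w + x_col + noise
        # keep z unchanged once step i is past a sequence's length (padding)
        z = np.where(i < lengths, z, z_prev)
        return z, None

    # scan over time; the final carry is the last valid z of each sequence
    return lax.scan(_body_fn, np.zeros(noises.shape[1]),
                    (noises, np.arange(noises.shape[0]), x_matrix.T))[0]


def reparam_model(y, x_matrix, lengths):
    w = numpyro.sample('w', dist.Uniform(0., 1.))
    sigma = numpyro.sample('sigma', dist.HalfNormal(3.))
    beta = numpyro.sample('beta', dist.HalfNormal(3.))
    L = int(np.max(lengths))  # lengths must be concrete (not traced) here
    with numpyro.plate("y", len(y)):
        with numpyro.plate("len", L):
            # non-centered: sample standard-normal noises, then scale by sigma;
            # padded positions (t >= lengths[n]) are masked out of the log density
            with handlers.mask(mask=np.arange(L)[..., None] < lengths):
                noises = sigma * numpyro.sample('noises', dist.Normal(0., 1.))
        # z has shape (N,), so it only needs the "y" plate
        z = scan_z(noises, x_matrix, w, lengths)
        numpyro.deterministic('z', z)
        numpyro.sample('R', dist.Normal(beta * z, 1.), obs=y)
```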
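One way to sanity-check the masking in scan_z, and the lax.fori_loop alternative mentioned above, is to compare both against a plain Python loop on padded toy data. This is a minimal sketch: scan_z_fori, naive_z and the toy data are hypothetical, written only for illustration, and it relies on the current JAX behavior that lax.fori_loop with static integer bounds is lowered to scan, so reverse-mode gradients (e.g. with respect to w) work through it.

```python
import jax
import jax.numpy as jnp
import numpy as onp
from jax import lax


def scan_z(noises, x_matrix, w, lengths):
    """Masked recursion with lax.scan, same logic as scan_z in the model."""
    def _body_fn(z_prev, val):
        noise, i, x_col = val
        z = z_prev * w + x_col + noise
        # Freeze z once step i is past the sequence's length.
        z = jnp.where(i < lengths, z, z_prev)
        return z, None

    init = jnp.zeros(noises.shape[1])
    steps = jnp.arange(noises.shape[0])
    return lax.scan(_body_fn, init, (noises, steps, x_matrix.T))[0]


def scan_z_fori(noises, x_matrix, w, lengths):
    """Hypothetical fori_loop variant of scan_z (same masking rule)."""
    noises = jnp.asarray(noises)
    x_matrix = jnp.asarray(x_matrix)
    lengths = jnp.asarray(lengths)

    def _body_fn(i, z_prev):
        z = z_prev * w + x_matrix[:, i] + noises[i]
        return jnp.where(i < lengths, z, z_prev)

    # Python-int bounds give a static trip count, so fori_loop lowers to scan.
    return lax.fori_loop(0, noises.shape[0], _body_fn, jnp.zeros(noises.shape[1]))


def naive_z(noises, x_matrix, w, lengths):
    """Hypothetical reference: run each sequence only for its own length."""
    out = []
    for n, length in enumerate(lengths):
        z = 0.0
        for t in range(int(length)):
            z = z * w + x_matrix[n, t] + noises[t, n]
        out.append(z)
    return onp.array(out)


# Toy padded data: 3 sequences of lengths 2, 5 and 3.
rng = onp.random.default_rng(0)
lengths = onp.array([2, 5, 3])
num_seqs, max_len = len(lengths), int(lengths.max())
x_matrix = rng.normal(size=(num_seqs, max_len))  # (N, L), as in the model
noises = rng.normal(size=(max_len, num_seqs))    # (L, N), as in the model
w = 0.7

z_ref = naive_z(noises, x_matrix, w, lengths)
z_scan = scan_z(noises, x_matrix, w, lengths)
z_fori = scan_z_fori(noises, x_matrix, w, lengths)


def loss(w_):
    # Any scalar function of z works; a sum of squares keeps it simple.
    return jnp.sum(scan_z_fori(noises, x_matrix, w_, lengths) ** 2)


g = jax.grad(loss)(w)  # reverse-mode gradient through fori_loop
print(onp.allclose(z_scan, z_ref, atol=1e-5), onp.allclose(z_fori, z_ref, atol=1e-5))
print(float(g))
```

Both JAX versions should agree with the Python loop, since padded steps leave the carry untouched; the scan version is what the model uses, and the fori_loop version is only an alternative once gradients through it are available.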