Understood! I thought you were using reparameterization, but after looking into the solution again, I understand that you just want to construct a corresponding one for the sequential model.
For users who are interested in the non-centered reparameterization that I mentioned above, here is the corresponding model (a slight modification of @neerajprad’s answer). This technique will mix much better than the centered one. Note that future versions of JAX will include a grad rule for while loops, which will allow us to use `lax.fori_loop` instead of `lax.scan` here (not important though).
def scan_z(noises, x_matrix, w, lengths):
    """Run the recursion ``z <- z * w + x + noise`` and return the final state.

    The scan walks over the leading axis of ``noises`` (and, in lockstep,
    over the leading axis of ``x_matrix.T``), starting from a zero state of
    shape ``(noises.shape[1],)``.  At step ``i`` only the entries with
    ``i < lengths`` are updated; the others keep their previous value, so
    each entry stops evolving once its own length has been reached.

    :param noises: array indexed as ``[step, series]``.
    :param x_matrix: array whose transpose is scanned step by step alongside
        ``noises``.
    :param w: multiplier applied to the previous state at every step.
    :param lengths: per-entry number of steps to apply; compared against the
        step index.
    :return: the state after the last step (the carry of the scan).
    """
    num_steps, num_series = noises.shape[0], noises.shape[1]

    def _step(state, inputs):
        eps, step, x_step = inputs
        candidate = state * w + x_step + eps
        # Freeze entries whose sequence has already ended.
        return np.where(step < lengths, candidate, state), None

    initial_state = np.zeros(num_series)
    scanned = (noises, np.arange(num_steps), x_matrix.T)
    final_state, _ = lax.scan(_step, initial_state, scanned)
    return final_state
def reparam_model(y, x_matrix, lengths):
    """Non-centered version of the sequential model.

    Standard-normal ``noises`` are sampled and scaled by ``sigma`` instead of
    sampling the latent states directly; the states are then rebuilt
    deterministically with ``scan_z`` and recorded under the site name ``z``.

    :param y: observed values; ``len(y)`` sets the size of the ``"y"`` plate.
    :param x_matrix: passed through to ``scan_z``.
    :param lengths: per-series lengths; the largest one sets the size of the
        ``"len"`` plate and the values drive the mask and ``scan_z``.
    """
    w = numpyro.sample('w', dist.Uniform(0., 1.))
    sigma = numpyro.sample('sigma', dist.HalfNormal(3.))
    beta = numpyro.sample('beta', dist.HalfNormal(3.))

    max_len = int(np.max(lengths))
    # True where the step index is still inside the corresponding series.
    in_range = np.arange(max_len)[..., None] < lengths

    with numpyro.plate("y", len(y)):
        # NOTE(review): the mask is passed positionally; in NumPyro's
        # documented signature the first positional parameter of
        # handlers.mask is `fn` -- confirm the installed version applies it
        # as the mask, otherwise pass it by keyword.
        with numpyro.plate("len", max_len), handlers.mask(in_range):
            unit_noises = numpyro.sample('noises', dist.Normal(0., 1.))
            noises = sigma * unit_noises
        z = scan_z(noises, x_matrix, w, lengths)
        numpyro.deterministic('z', z)
        numpyro.sample('R', dist.Normal(beta * z, 1.), obs=y)