I’m trying to use numpyro’s mask with the following model:

```
z = 0
for i in range(len_seq):
    z ~ Normal(z * w + x[i], sigma)
y ~ Bernoulli(sigmoid(beta * z))
```

where the sequences don’t have the same length.

The samples I get back have an unexpected shape (see below).

```
#Example for numpyro.handlers.mask
import numpyro
from numpyro import handlers
import numpy
import jax
import jax.numpy as np
from numpyro.infer import MCMC, NUTS
import numpyro.distributions as dist
def model(y, x_matrix, lengths):
    """Per-sequence latent state chain with a Bernoulli outcome on its final state.

    For each sequence the latent state starts at 0 and is updated as
    ``z ~ Normal(z * w + x[i], sigma)`` for every real time step ``i``; the
    observation is then ``y ~ Bernoulli(sigmoid(beta * z))``.

    Args:
        y: one binary observation per sequence; its length sets the plate size.
        x_matrix: 2-D array indexed as ``x_matrix[:, i]``, i.e. one row per
            sequence and at least ``lengths.max()`` columns (padded steps are
            never used in the likelihood).
        lengths: number of real time steps in each sequence; compared
            elementwise against the step index to build the mask.

    Note:
        Masking only removes the log-probability contribution of the padded
        entries. Every ``z_i`` site is still sampled with the full plate
        shape, so the collected samples have shape ``(num_samples, len(y))``
        for every step; entries beyond a sequence's length are placeholders
        and should be ignored.
    """
    w = numpyro.sample('w', dist.Uniform(0., 1.))
    sigma = numpyro.sample('sigma', dist.HalfNormal(3.))
    beta = numpyro.sample('beta', dist.HalfNormal(3.))
    with numpyro.plate("data", len(y)):
        z = np.zeros(len(y))
        for i in range(int(lengths.max())):
            # True for sequences that still have a real observation at step i.
            active = i < lengths
            # The mask has to be passed by keyword: the first positional
            # parameter of handlers.mask is `fn`, so a positional array would
            # leave the mask at its default (nothing masked).
            # NOTE(review): older NumPyro releases name this keyword
            # `mask_array` -- confirm against the installed version.
            with handlers.mask(mask=active):
                z_i = numpyro.sample(
                    'z_%d' % i,
                    dist.Normal(z * w + x_matrix[:, i], sigma),
                )
            # Sequences that have already ended keep their last real state, so
            # the likelihood of y uses z at step lengths - 1 rather than a
            # value drawn for a padded step.
            z = np.where(active, z_i, z)
        numpyro.sample('y', dist.Bernoulli(logits=beta * z), obs=y)
# --- Data -------------------------------------------------------------------
# Binary outcome observed for each sequence.
y = np.array([1, 0, 0, 1, 1, 0])
n_obs = len(y)
# Number of time steps in each sequence.
lengths = np.array([4, 4, 6, 6, 6, 10])
# One standard-normal covariate series per sequence, keyed by sequence index.
x_dict = {
    j: numpy.random.normal(loc=0, scale=1, size=size)
    for j, size in enumerate(lengths)
}

# Pack the ragged series into a zero-padded (n_obs, longest length) matrix.
x_matrix = numpy.zeros((n_obs, lengths.max()))
for j, row in enumerate(x_matrix):
    series = x_dict[j]
    row[:len(series)] = series  # `row` is a view, so this writes into x_matrix
x_matrix = np.array(x_matrix)

# --- Inference with NUTS ----------------------------------------------------
rng_key = jax.random.PRNGKey(0)
kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=100, num_samples=200)
mcmc.run(rng_key, y=y, x_matrix=x_matrix, lengths=lengths)
mcmc_samples = mcmc.get_samples()

# Observed output:
# dict_keys(['beta', 'sigma', 'w', 'z_0', 'z_1', 'z_2', 'z_3', 'z_4', 'z_5', 'z_6', 'z_7', 'z_8', 'z_9'])
print(mcmc_samples.keys())
# (200, 6) as expected
print(mcmc_samples['z_0'].shape)
# (200, 6), but I would expect (200,1): there is only one sequence of length 10
print(mcmc_samples['z_9'].shape)
```

That makes me think I didn’t use the mask properly. Any hints?