How do I use nested plate notation in NumPyro for a dataset that looks like this?

@mathlad As in the CJS example, for instance, you can use numpyro.handlers.mask, which works identically to pyro.poutine.mask, to ignore the missing parameters and observations in your example:

def model(data, levels_mask, observations_mask):
    # levels_mask must broadcast against (n_levels, n_participants, 1);
    # observations_mask against (n_levels, n_participants, max_n_observations).
    ...
    with numpyro.plate("n_levels", n_levels, dim=-3):
        a_level = numpyro.sample("a_level", dist.HalfNormal(a_level_global_scale))
        b_level = numpyro.sample("b_level", dist.HalfNormal(b_level_global_scale))

        # Mask out participant slots that do not exist in a given level.
        with numpyro.plate("n_participants", n_participants, dim=-2), \
                numpyro.handlers.mask(mask=levels_mask):
            a = numpyro.sample("a", dist.Normal(a_level, a_global_scale))
            b = numpyro.sample("b", dist.Normal(b_level, b_global_scale))

            # Mask out padded (missing) observations.
            with numpyro.plate("n_observations", max_n_observations, dim=-1), \
                    numpyro.handlers.mask(mask=observations_mask):
                return numpyro.sample("output", ..., obs=data)
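
The model takes a padded data array plus two boolean masks, so here is a rough sketch of how those could be built from ragged data. The nested-list layout `ragged[level][participant]` and the toy values are my assumptions (your dataset may be organized differently); only the output shapes are dictated by the plate dims in the model.

```python
import numpy as np

# Hypothetical ragged input: ragged[level][participant] is a list of observations.
ragged = [
    [[1.0, 2.0, 3.0], [4.0]],  # level 0: two participants
    [[5.0, 6.0]],              # level 1: one participant
]

n_levels = len(ragged)
n_participants = max(len(level) for level in ragged)
max_n_observations = max(len(obs) for level in ragged for obs in level)

# Pad to a dense (n_levels, n_participants, max_n_observations) array.
# The fill value 0.0 is arbitrary; it is masked out in the model.
data = np.zeros((n_levels, n_participants, max_n_observations))
observations_mask = np.zeros(data.shape, dtype=bool)

for i, level in enumerate(ragged):
    for j, obs in enumerate(level):
        data[i, j, : len(obs)] = obs
        observations_mask[i, j, : len(obs)] = True

# A participant slot counts as present if it has at least one observation.
# keepdims=True leaves a trailing singleton axis so the mask broadcasts
# against sites that sit inside the dim=-3 and dim=-2 plates only.
levels_mask = observations_mask.any(axis=-1, keepdims=True)
```

If the likelihood has restricted support, it is probably safer to pad with a value inside that support rather than 0.0, so the masked-out entries do not produce NaNs.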

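This is statistically optimal, since masked-out entries contribute zero to the log density. It may be wasteful computationally, because the padded entries are still computed, but if your data is not too sparse the parallelism inherent in this approach is probably worth it.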

See also this forum thread for a longer and better-explained Pyro example, or search for others that use mask to encode patterns of sparsity.
