Plate dim error with Bernoulli but not Normal sample

Here’s the simplified example.

import numpy as np
import numpyro
from numpyro.infer import MCMC, NUTS
import numpyro.distributions as dist
import jax.nn as jnn
import jax.numpy as jnp
from jax import random


def toy_model(x, y):
    x_offset = numpyro.sample("x_offset", dist.Normal(0, 1), sample_shape=(1, x.shape[1]))
    y_offset = numpyro.sample("y_offset", dist.Normal(0, 1), sample_shape=(1, y.shape[1]))
    weights = numpyro.sample("weights", dist.Normal(0, np.sqrt(x.shape[1])), sample_shape=(x.shape[1], y.shape[1]))
    activation = jnn.sigmoid(x-x_offset)
    is_active = numpyro.sample("is_active", dist.Bernoulli(activation))
    predicted_means = jnp.matmul(is_active, weights) + y_offset
    predictions = numpyro.sample("predictions", dist.Normal(predicted_means, .1), obs=y)

x = np.random.normal(size=(1000, 100))
y = np.stack([x[:, 0] + x[:, 1], x[:, 3] - x[:, 4]], axis=1) + np.random.normal(scale=.1, size=(len(x), 2))

rng_key = random.PRNGKey(np.random.randint(1e6))
num_warmup=1000
num_samples=2000
kernel = NUTS(toy_model)
mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)
mcmc.run(rng_key, x=x, y=y)
mcmc.print_summary()
samples = mcmc.get_samples()

produces

ValueError: Missing a plate statement for batch dimension -1 at site 'x_offset'. You can use numpyro.util.format_shapes utility to check shapes at all sites of your model.

But after changing the Bernoulli sample to dist.Normal(loc=activation) it runs fine. The shapes in both cases are exactly the same:

   Trace Shapes:
    Param Sites:
   Sample Sites:
   x_offset dist          |
           value    1 100 |
   y_offset dist          |
           value    1   2 |
    weights dist          |
           value  100   2 |
  is_active dist 1000 100 |
           value 1000 100 |
predictions dist 1000   2 |
           value 1000   2 |

vs

   Trace Shapes:
    Param Sites:
   Sample Sites:
   x_offset dist          |
           value    1 100 |
   y_offset dist          |
           value    1   2 |
    weights dist          |
           value  100   2 |
  is_active dist 1000 100 |
           value 1000 100 |
predictions dist 1000   2 |
           value 1000   2 |

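Hi @JoshDempster, this is expected. NUTS cannot sample discrete latent variables directly, so they have to be marginalized out, and that marginalization needs a plate annotation for every batch dimension. Plate annotations are optional for models without discrete latent variables.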
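
To give a rough idea of what "marginalized out" means here, below is a small sketch in plain Python rather than NumPyro. It uses a made-up toy model (z_i ~ Bernoulli(p_i), y_i ~ Normal(z_i, scale)), and all function names are hypothetical. It is not how NumPyro implements enumeration. It only shows that when the elements are declared independent, which is what a plate states, the sum over the discrete values can be done one element at a time instead of over all 2**N joint configurations.

```python
import itertools
import math


def normal_logpdf(y, mean, scale):
    """Log density of Normal(mean, scale) evaluated at y."""
    z = (y - mean) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def logsumexp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def marginal_factorized(ys, probs, scale=1.0):
    """Sum out each z_i separately. This is only valid because the
    elements are independent, which is what a plate would declare."""
    total = 0.0
    for y, p in zip(ys, probs):
        total += logsumexp([
            math.log(1.0 - p) + normal_logpdf(y, 0.0, scale),  # z_i = 0
            math.log(p) + normal_logpdf(y, 1.0, scale),        # z_i = 1
        ])
    return total


def marginal_brute_force(ys, probs, scale=1.0):
    """Sum over all 2**N joint configurations of z, for comparison."""
    terms = []
    for zs in itertools.product([0, 1], repeat=len(ys)):
        lp = 0.0
        for y, p, z in zip(ys, probs, zs):
            lp += math.log(p if z == 1 else 1.0 - p)
            lp += normal_logpdf(y, float(z), scale)
        terms.append(lp)
    return logsumexp(terms)


# Toy data, chosen arbitrarily for illustration.
ys = [0.1, 0.9, 1.2, -0.3]
probs = [0.2, 0.7, 0.5, 0.4]
factorized = marginal_factorized(ys, probs)
brute = marginal_brute_force(ys, probs)
```

Both functions compute the same log marginal likelihood, but the factorized version does N two-term sums while the brute-force version grows as 2**N. Without the independence information, a sampler has no safe way to pick the cheap version, which is one way to read why the plate statement stops being optional once a discrete latent site is present.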