Here’s the simplified example.
import numpyro
from numpyro.infer import MCMC, NUTS
import numpyro.distributions as dist
import jax.numpy as jnp
from jax import random
def toy_model(x, y):
x_offset = numpyro.sample("x_offset", dist.Normal(0, 1), sample_shape=(1, x.shape[1]))
y_offset = numpyro.sample("y_offset", dist.Normal(0, 1), sample_shape=(1, y.shape[1]))
weights = numpyro.sample("weights", dist.Normal(0, np.sqrt(x.shape[1])), sample_shape=(x.shape[1], y.shape[1]))
activation = jnn.sigmoid(x-x_offset)
is_active = numpyro.sample("is_active", dist.Bernoulli(activation))
predicted_means = jnp.matmul(is_active, weights) + y_offset
predictions = numpyro.sample("predictions", dist.Normal(predicted_means, .1), obs=y)
x = np.random.normal(size=(1000, 100))
y = np.stack([x[:, 0] + x[:, 1], x[:, 3] - x[:, 4]], axis=1) + np.random.normal(scale=.1, size=(len(x), 2))
rng_key = random.PRNGKey(np.random.randint(1e6))
num_warmup=1000
num_samples=2000
kernel = NUTS(toy_model)
mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)
mcmc.run(rng_key, x=x, y=y)
mcmc.print_summary()
samples = mcmc.get_samples()
produces
ValueError: Missing a plate statement for batch dimension -1 at site ‘x_offset’. You can use
numpyro.util.format_shapes
utility to check shapes at all sites of your model.
But after changing the Bernoulli sample to dist.Normal(loc=activation)
it runs fine. The shapes in both cases are exactly the same:
Trace Shapes:
Param Sites:
Sample Sites:
x_offset dist |
value 1 100 |
y_offset dist |
value 1 2 |
weights dist |
value 100 2 |
is_active dist 1000 100 |
value 1000 100 |
predictions dist 1000 2 |
value 1000 2 |
vs
Trace Shapes:
Param Sites:
Sample Sites:
x_offset dist |
value 1 100 |
y_offset dist |
value 1 2 |
weights dist |
value 100 2 |
is_active dist 1000 100 |
value 1000 100 |
predictions dist 1000 2 |
value 1000 2 |