Thanks @ordabayev. For one, numpyro is not happy about the mismatched shapes of the prior for b (i.e. shape=()) and the approximate conditional posterior for b (i.e. shape=(p,)). Fixing that to use a flat, shared prior over b still results in nonsensical inference. For example,
def model(X, y, l_dim=1):
    """Single-effect regression model: one causal column of X chosen by gamma.

    X : (n, p) design matrix; y : (n,) response; l_dim currently unused.
    """
    n_dim, p_dim = X.shape
    # Uniform prior over which predictor is causal.
    prior_probs = jnp.ones(p_dim) / float(p_dim)
    # Per-predictor prior scale for the effect size b.
    prior_scale = jnp.ones(p_dim) * 1e-3
    gamma = numpyro.sample("gamma", dist.Categorical(prior_probs))
    b = numpyro.sample("b", dist.Normal(0.0, prior_scale[..., gamma]))
    # Learned observation noise (point estimate, constrained positive).
    sigma_e = numpyro.param("sigma_e", 0.9, constraint=constraints.positive)
    with numpyro.plate("N", n_dim):
        mean = Vindex(X)[:, gamma] * b
        numpyro.sample("y", dist.Normal(mean, sigma_e), obs=y)
    return
def guide(X, y, l_dim=1):
    """Mean-field guide: Categorical over predictor index, Normal over effect size.

    BUG FIX: ``numpyro.param`` takes the keyword ``constraint`` (singular).
    The original passed ``constraints=``, which is silently ignored, so
    ``alpha`` was not kept on the simplex and ``post_sigma_b`` was not kept
    positive — the optimizer drove it negative (see the reported
    ``results.params["post_sigma_b"]`` with negative entries), producing
    nan losses after the first step.
    """
    n_dim, p_dim = X.shape
    # Variational posterior probabilities over the causal predictor.
    alpha = numpyro.param(
        "alpha", jnp.ones(p_dim) / float(p_dim), constraint=constraints.simplex
    )
    gamma = numpyro.sample(
        "gamma", dist.Categorical(alpha), infer={"enumerate": "parallel"}
    )
    # Per-predictor conditional posterior scale/location for b.
    post_sigma_b = numpyro.param(
        "post_sigma_b", jnp.ones(p_dim) * 1e-3, constraint=constraints.positive
    )
    post_mu_b = numpyro.param("post_mu_b", jnp.zeros(p_dim))
    # Select the (mu, sigma) pair of the enumerated/ sampled gamma.
    b = numpyro.sample(
        "b",
        dist.Normal(Vindex(post_mu_b)[..., gamma], Vindex(post_sigma_b)[..., gamma]),
    )
    return
#[...]
# Optimize the ELBO with Adam; TraceEnum_ELBO sums out the discrete site
# "gamma", which the guide marks with infer={"enumerate": "parallel"}.
adam = optim.Adam(step_size=0.005)
svi = SVI(model, guide, adam, TraceEnum_ELBO(max_plate_nesting=10))
# NOTE(review): stable_update=True skips parameter updates with non-finite
# gradients, but the reported losses still go nan after the first step.
results = svi.run(
rng_key_run,
args.epochs,
X=X,
y=y,
l_dim=args.l_dim,
progress_bar=True,
stable_update=True,
)
Here is some summarized output showing nan losses,
caused by the posterior standard deviation being inferred incorrectly (it goes negative):
Param Sites:
sigma_e
Sample Sites:
gamma dist |
value |
log_prob |
b dist |
value |
log_prob |
N plate 400 |
y dist 400 |
value 400 |
log_prob 400 |
results.losses = Array([15483.41834474, nan, nan, nan,
nan, nan, nan, nan,
nan, nan], dtype=float64, weak_type=True)
results.params["post_sigma_b"] = Array([-0.004, -0.004, 0.006, -0.004, 0.006, -0.004, 0.006, 0.006,
0.006, -0.004, -0.004, 0.006, -0.004, -0.004, 0.006, -0.004,
-0.004, -0.004, 0.006, 0.006], dtype=float64)
I’ve uploaded the entire simulation and sample code here.