Hello, I’m trying to fit a linear random-effects model using numpyro and have been comparing the autoguide vs my custom guide to assess inference stability. I am continually running into issues with the loss increasing, then decreasing under my custom guide, whereas the auto-guide does not exhibit these issues.
My guide should roughly resemble the auto-guide (mostly a meanfield with exception for a few separate parameters), so I’m a bit confused as how this can happen.
Here are losses at various epochs under my guide vs the auto-guide:
My-Guide
iter 0 - loss = 7015422.0
iter 25 - loss = 59737.24609375
iter 50 - loss = 36599.234375
iter 75 - loss = 25873.19921875
iter 100 - loss = 24866.625
iter 125 - loss = 7331.56005859375
iter 150 - loss = 25220.345703125
Auto-Guide
iter 0 - loss = 4951479.5
iter 25 - loss = 193485.234375
iter 50 - loss = 112910.3828125
iter 75 - loss = 85999.2421875
iter 100 - loss = 62530.828125
iter 125 - loss = 52520.109375
iter 150 - loss = 42750.94140625
Here are the model and guide definitions.
def model(X_1: jnp.ndarray, W_1: jnp.ndarray, y_1: jnp.ndarray,
X_2: jnp.ndarray, W_2: jnp.ndarray, y_2: jnp.ndarray = None) -> None:
n_1, p = X_1.shape
n_2, p = X_2.shape
# coupling parameter
s_0 = 1
s = numpyro.sample("s", dist.MultivariateNormal(0., s_0 * jnp.eye(2)))
# prior variance
sigma_b = numpyro.param("sigma_b", jnp.sqrt(0.3 / p), constrain=constraints.positive)
# effect sizes;
with numpyro.plate("beta_i", p):
beta_1 = numpyro.sample("beta_1", dist.Normal(0., W_1 ** (s[0]/2.) * sigma_b))
beta_2 = numpyro.sample("beta_2", dist.Normal(0., W_2 ** (s[1]/2.) * sigma_b))
# environmental var
sigma_e1 = numpyro.param("sigma_e1", 1., constrain=constraints.positive)
sigma_e2 = numpyro.param("sigma_e2", 1., constrain=constraints.positive)
# likelihood/data generation
mu_1 = jnp.dot(X_1, beta_1)
mu_2 = jnp.dot(X_2, beta_2)
with numpyro.plate("data", n_1):
numpyro.sample("y_1", dist.Normal(mu_1, sigma_e1), obs=y_1)
numpyro.sample("y_2", dist.Normal(mu_2, sigma_e2), obs=y_2)
return
def guide(X_1: jnp.ndarray, W_1: jnp.ndarray, y_1: jnp.ndarray,
X_2: jnp.ndarray, W_2: jnp.ndarray, y_2: jnp.ndarray = None) -> None:
n_1, p = X_1.shape
n_2, p = X_2.shape
# mean / sd for parameter s
s_loc = numpyro.param("s_loc", jnp.zeros(2))
s_scale = numpyro.param("s_scale", jnp.eye(2), constrain=constraints.positive_definite)
# approximate multivariate-normal posterior
s = numpyro.sample("s", dist.MultivariateNormal(s_loc, s_scale))
# posterior means for betas
beta_1_loc = numpyro.param("beta_1_loc", jnp.zeros(p))
beta_2_loc = numpyro.param("beta_2_loc", jnp.zeros(p))
# posterior sd for betas
beta_1_scale = jnp.exp(numpyro.param("beta_1_scale", jnp.ones(p) / n_1))
beta_2_scale = jnp.exp(numpyro.param("beta_2_scale", jnp.ones(p) / n_2))
with numpyro.plate("beta_i", p):
# mean-field approximation that all variants are independent in normal posterior
beta_1 = numpyro.sample("beta_1", dist.Normal(beta_1_loc, beta_1_scale))
beta_2 = numpyro.sample("beta_2", dist.Normal(beta_2_loc, beta_2_scale))
return
What could be going on here?