When I run inference on my model with SVI, I use the `init_to_mean` strategy. My understanding is that every latent parameter is initialized to its prior mean, and that if a distribution has no mean, the `init_to_median` strategy is used instead. However, when I inspect a parameter's value right after initialization, it is not what I expect. For example, I would expect a parameter with a Normal(0.0, 1.0) prior to be initialized to 0.0, but that is not the case. Here is a minimal working example:
```python
import jax
import numpyro
import numpyro.distributions as dist
from numpyro.infer.autoguide import AutoDelta
from numpyro.infer import init_to_mean, SVI, Trace_ELBO
from numpyro.optim import Adam


def model():
    numpyro.sample("x", dist.Normal(0.0, 1.0))


# Set up AutoDelta with init_to_mean
guide = AutoDelta(model, init_loc_fn=init_to_mean())
optimizer = Adam(0.01)
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

rng_key = jax.random.PRNGKey(0)
svi_state = svi.init(rng_key)

# Inspect the guide's initial parameter values (the point estimate for x)
print("Initial point estimate for x:", svi.get_params(svi_state))
```