I am trying to use numpyro to fit a large model with ~1500 latent parameters and a dense covariance matrix. To help NUTS along I have broken the problem into several stages:

- Use SVI with `AutoDelta` to find a decent MAP solution (works well)
- Use this MAP solution as the initial location for `AutoMultivariateNormal` to estimate the dense covariance matrix
- Use this covariance matrix as the `inverse_mass_matrix` for NUTS so it does not have to learn it during warmup. Since the true posterior is not an MVN, use NUTS to get a more accurate sample.
My questions:

- Are the `auto_scale_tril` values from the `AutoMultivariateNormal` in the same order as NUTS's `inverse_mass_matrix` when applied to the same model?
- Are the scales the same?
- Do I need to somehow get the "unscaled" covariance matrix out of `AutoMultivariateNormal` before trying to pass it forward?
When I tried this I ended up with all draws divergent.
Pseudocode:

```python
import jax
import optax
from numpyro import infer
from numpyro.infer import MCMC, NUTS, SVI, TraceMeanField_ELBO, autoguide
from numpyro.infer.autoguide import AutoDelta

def model(data):
    # large model with ~1500 parameters
    # (estimating pixel values in a 39x39 image and ~20 model parameters on top of that)
    ...

# Find MAP value
guide_delta = AutoDelta(model, init_loc_fn=infer.init_to_sample)
svi_delta = SVI(
    model,
    guide_delta,
    optax.adabelief(0.01),
    TraceMeanField_ELBO()
)
rng_key, rng_key_ = jax.random.split(rng_key)
svi_delta_result = svi_delta.run(
    rng_key_,
    10000,
    data,
    progress_bar=True
)

# Extend MAP to a dense MultivariateNormal
base_map_values = guide_delta.median(svi_delta_result.params)
guide_normal = autoguide.AutoMultivariateNormal(
    model,
    init_loc_fn=infer.init_to_value(values=base_map_values)
)
svi_normal = SVI(
    model,
    guide_normal,
    optax.adabelief(0.01),
    TraceMeanField_ELBO(num_particles=50)
)
rng_key, rng_key_ = jax.random.split(rng_key)
svi_normal_result = svi_normal.run(
    rng_key_,
    1500,
    data,
    progress_bar=True
)

# The full covariance matrix of the guide (in unconstrained space)
svi_normal_cov = svi_normal_result.params['auto_scale_tril'].dot(
    svi_normal_result.params['auto_scale_tril'].T
)
map_values_normal = guide_normal.median(svi_normal_result.params)

# HMC sample with the learned covariance as a fixed dense mass matrix
kernel = NUTS(
    model,
    init_strategy=infer.init_to_value(values=map_values_normal),
    dense_mass=True,
    inverse_mass_matrix=svi_normal_cov,
    adapt_mass_matrix=False,
    target_accept_prob=0.8
)
mcmc = MCMC(
    kernel,
    num_warmup=1000,
    num_samples=2000,
    num_chains=6,
    progress_bar=True,
)
rng_key, rng_key_ = jax.random.split(rng_key)
mcmc.run(rng_key_, data)
```
I also tried using `NeuTraReparam` with the learned MVN guide, but this resulted in slower HMC steps and worse r_hat values out the other end.