Using SVI MultivariateNormal to estimate NUTS inverse_mass_matrix

I am trying to use NumPyro to fit a large model with ~1500 latent parameters whose posterior has a dense (strongly correlated) covariance matrix. To help NUTS along, I have broken the problem into several stages:

  1. Use SVI with AutoDelta to find a decent MAP solution (works well)
  2. Use this MAP solution as the initial location for AutoMultivariateNormal to estimate the dense covariance matrix
  3. Use this covariance matrix as the inverse_mass_matrix for NUTS so that it does not have to learn it during warmup. Because the true posterior is not multivariate normal, NUTS is still needed to get an accurate sample.

My questions:

  1. Are the rows/columns of auto_scale_tril from AutoMultivariateNormal in the same order as NUTS's inverse_mass_matrix when both are applied to the same model?
  2. Are the two on the same scale?
  3. Do I need to somehow get the “unscaled” cov matrix out of AutoMultivariateNormal before trying to pass it forward?
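On question 1, my understanding of NumPyro's internals (worth verifying against the installed version) is that both AutoContinuous guides such as AutoMultivariateNormal and the HMC/NUTS kernel flatten the dict of unconstrained latent sites with `jax.flatten_util.ravel_pytree`, which concatenates dict entries in sorted-key order, each site flattened row-major. If that holds, the two matrices share the same ordering, and both live in unconstrained space. The sketch below only demonstrates the `ravel_pytree` ordering; the site names and shapes are made up.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Hypothetical unconstrained latent values keyed by sample-site name
# (made-up names and tiny shapes standing in for the real model's sites).
latents = {
    "pixels": jnp.arange(4.0).reshape(2, 2),
    "amplitude": jnp.array(10.0),
    "background": jnp.array([20.0, 21.0]),
}

# ravel_pytree concatenates the leaves of a dict in sorted-key order.
flat, unravel = ravel_pytree(latents)
print(flat)  # amplitude first, then background, then pixels (row-major)

# Slice of the flat vector (and of the rows/columns of a flattened covariance)
# that belongs to each site, assuming the same sorted-key flattening.
offsets = {}
start = 0
for name in sorted(latents):
    size = int(latents[name].size)
    offsets[name] = slice(start, start + size)
    start += size
print(offsets)
```

With the real model, the sizes would have to be those of the *unconstrained* values, which differ from the constrained ones for some supports (e.g. a K-element simplex is unconstrained to K-1 values). A block such as `svi_normal_cov[offsets[name], offsets[name]]` could then be compared per site.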

When I tried this, every draw was divergent.

Pseudocode:

import jax
import numpyro
import optax
from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO, init_to_sample, init_to_value
from numpyro.infer.autoguide import AutoDelta, AutoMultivariateNormal

# Let the 6 chains run in parallel on CPU; call this before JAX runs any computation
numpyro.set_host_device_count(6)

def model(data):
    # large model with ~1500 parameters
    # (estimating pixel values in a 39x39 image, plus ~20 model parameters on top of that)
    ...

rng_key = jax.random.PRNGKey(0)

# Stage 1: find the MAP value
guide_delta = AutoDelta(model, init_loc_fn=init_to_sample)
svi_delta = SVI(
    model,
    guide_delta,
    optax.adabelief(0.01),
    Trace_ELBO()
)
rng_key, rng_key_ = jax.random.split(rng_key)
svi_delta_result = svi_delta.run(
    rng_key_,
    10000,
    data,
    progress_bar=True
)

# Stage 2: extend the MAP point to a dense MultivariateNormal
base_map_values = guide_delta.median(svi_delta_result.params)
guide_normal = AutoMultivariateNormal(
    model,
    init_loc_fn=init_to_value(values=base_map_values)
)
svi_normal = SVI(
    model,
    guide_normal,
    optax.adabelief(0.01),
    # Trace_ELBO makes no mean-field assumption, which suits a joint (dense) guide
    Trace_ELBO(num_particles=50)
)
rng_key, rng_key_ = jax.random.split(rng_key)
svi_normal_result = svi_normal.run(
    rng_key_,
    1500,
    data,
    progress_bar=True
)

# The guide's full covariance matrix: scale_tril @ scale_tril.T
scale_tril = svi_normal_result.params["auto_scale_tril"]
svi_normal_cov = scale_tril @ scale_tril.T

map_values_normal = guide_normal.median(svi_normal_result.params)

# Stage 3: NUTS with a fixed dense inverse mass matrix (step size is still adapted in warmup)
kernel = NUTS(
    model,
    init_strategy=init_to_value(values=map_values_normal),
    dense_mass=True,
    inverse_mass_matrix=svi_normal_cov,
    adapt_mass_matrix=False,
    target_accept_prob=0.8
)
mcmc = MCMC(
    kernel,
    num_warmup=1000,
    num_samples=2000,
    num_chains=6,
    progress_bar=True,
)
rng_key, rng_key_ = jax.random.split(rng_key)
mcmc.run(rng_key_, data)
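One thing that may be worth ruling out when every draw diverges is a numerically poor matrix: with ~1500 dimensions and JAX's default float32, `scale_tril @ scale_tril.T` can come out nearly singular or very badly conditioned. This is a guess at a possible contributing factor, not a diagnosis. The sketch below defines a hypothetical helper, `check_inverse_mass_matrix` (not part of NumPyro), that upcasts to float64 and reports finiteness, asymmetry, positive definiteness and the eigenvalue spread; the 2x2 `scale_tril` is a toy example.

```python
import numpy as np


def check_inverse_mass_matrix(matrix):
    """Hypothetical helper: basic sanity checks on a dense matrix before
    passing it to NUTS as inverse_mass_matrix."""
    m = np.asarray(matrix, dtype=np.float64)  # upcast, e.g. from float32 JAX arrays
    report = {"finite": bool(np.all(np.isfinite(m)))}
    if not report["finite"]:
        return report
    report["max_asymmetry"] = float(np.max(np.abs(m - m.T)))
    # eigvalsh assumes a symmetric matrix, so symmetrize first
    eigvals = np.linalg.eigvalsh(0.5 * (m + m.T))
    report["min_eig"] = float(eigvals[0])
    report["max_eig"] = float(eigvals[-1])
    report["positive_definite"] = bool(eigvals[0] > 0.0)
    report["condition_number"] = (
        float(eigvals[-1] / eigvals[0]) if eigvals[0] > 0.0 else float("inf")
    )
    return report


# Toy example: a scale_tril with one tiny diagonal entry gives a nearly singular covariance.
scale_tril = np.array([[1.0, 0.0],
                       [0.5, 1e-4]])
cov = scale_tril @ scale_tril.T
print(check_inverse_mass_matrix(cov))
```

Applied to the pseudocode, the call would be `check_inverse_mass_matrix(svi_normal_cov)`; a huge condition number or a non-positive smallest eigenvalue would suggest looking at the matrix before looking at NUTS.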

I also tried NeuTraReparam with the learned MVN guide, but this resulted in slower HMC steps and worse r_hat values.

While it could conceivably work for some problems, I've never had any success with an approach like this, basically because the variational covariance tends to be (perhaps grossly) underestimated.
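Using the same whitening view of HMC as for stage 3, the kind of underestimation matters. An inverse mass matrix that is too small by one overall factor only rescales the whitened posterior, which NUTS can absorb through step-size adaptation (still on by default when `adapt_mass_matrix=False`). Shrinkage that differs between directions leaves the whitened posterior anisotropic. The sketch below uses a made-up diagonal posterior; the shrink factors are illustrative, not measurements of any real SVI fit, and `whitened_condition_number` is a hypothetical helper.

```python
import numpy as np


def whitened_condition_number(posterior_cov, inverse_mass_matrix):
    """Condition number of L^{-1} @ posterior_cov @ L^{-T},
    where inverse_mass_matrix = L @ L.T (hypothetical helper)."""
    chol = np.linalg.cholesky(inverse_mass_matrix)
    tmp = np.linalg.solve(chol, posterior_cov)  # L^{-1} Sigma
    whitened = np.linalg.solve(chol, tmp.T)     # L^{-1} Sigma L^{-T}
    return np.linalg.cond(whitened)


# Made-up posterior covariance (diagonal for simplicity).
posterior_cov = np.diag([1.0, 4.0, 9.0, 16.0])

# Uniform underestimate: every variance shrunk by the same factor.
uniform_guess = 0.25 * posterior_cov
# Uneven underestimate: some directions shrunk far more than others.
uneven_guess = np.diag(np.diag(posterior_cov) * np.array([1.0, 0.5, 0.1, 0.01]))

print(whitened_condition_number(posterior_cov, uniform_guess))  # about 1
print(whitened_condition_number(posterior_cov, uneven_guess))   # about 100
```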

Instead, I recommend trying some of the suggestions here.