I am trying to use NumPyro to fit a large model with ~1500 latent parameters and a dense covariance structure. To help NUTS along, I have broken the problem into several stages:

- Use SVI with `AutoDelta` to find a decent MAP solution (works well).
- Use this MAP solution as the initial location for `AutoMultivariateNormal` to estimate the dense covariance matrix.
- Use this covariance matrix as the `inverse_mass_matrix` for NUTS so it does not have to learn it during warmup.
- Since the true posterior is not an MVN, use NUTS to get a more accurate sample.
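The covariance extraction behind the last two steps can be sketched in plain NumPy; here a synthetic lower-triangular `L` stands in for the guide's learned `auto_scale_tril` (which factors the guide covariance over the flattened unconstrained latent vector):

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic stand-in for the guide's learned Cholesky factor:
# `auto_scale_tril` is a lower-triangular D x D matrix over the
# flattened unconstrained latent vector.
D = 5
L = np.tril(rng.normal(size=(D, D)))
np.fill_diagonal(L, np.abs(np.diag(L)) + 0.1)  # ensure a positive diagonal

# Covariance implied by the MVN guide: cov = L @ L.T
cov = L @ L.T

# A dense inverse mass matrix must be symmetric positive definite;
# by construction L @ L.T satisfies both.
assert np.allclose(cov, cov.T)
assert np.all(np.linalg.eigvalsh(cov) > 0)
```

This is the quantity I am passing forward as the fixed `inverse_mass_matrix`, on the assumption that NUTS wants (an approximation of) the posterior covariance in unconstrained space.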

My questions:

- Are the `auto_scale_tril` values from the `AutoMultivariateNormal` in the same order as NUTS's `inverse_mass_matrix` when applied to the same model?
- Are the scales the same?
- Do I need to somehow get the "unscaled" covariance matrix out of `AutoMultivariateNormal` before trying to pass it forward?

When I tried this, I ended up with all draws divergent.

Pseudo code:

```python
import jax
import optax

from numpyro import infer
from numpyro.infer import MCMC, NUTS, SVI, TraceMeanField_ELBO, autoguide
from numpyro.infer.autoguide import AutoDelta
from numpyro.infer.initialization import init_to_value


def model(data):
    # large model with ~1500 parameters
    # (estimating pixel values in a 39x39 image and ~20 model parameters on top of that)
    ...


rng_key = jax.random.PRNGKey(0)

# Stage 1: find a MAP solution with a point-mass (AutoDelta) guide
guide_delta = AutoDelta(model, init_loc_fn=infer.init_to_sample())
svi_delta = SVI(
    model,
    guide_delta,
    optax.adabelief(0.01),
    TraceMeanField_ELBO(),
)
rng_key, rng_key_ = jax.random.split(rng_key)
svi_delta_result = svi_delta.run(
    rng_key_,
    10000,
    data,
    progress_bar=True,
)

# Stage 2: extend the MAP to a dense MultivariateNormal guide
base_map_values = guide_delta.median(svi_delta_result.params)
guide_normal = autoguide.AutoMultivariateNormal(
    model,
    init_loc_fn=infer.init_to_value(values=base_map_values),
)
svi_normal = SVI(
    model,
    guide_normal,
    optax.adabelief(0.01),
    TraceMeanField_ELBO(num_particles=50),
)
rng_key, rng_key_ = jax.random.split(rng_key)
svi_normal_result = svi_normal.run(
    rng_key_,
    1500,
    data,
    progress_bar=True,
)

# Stage 3: recover the full covariance matrix from the Cholesky factor
svi_normal_cov = svi_normal_result.params['auto_scale_tril'].dot(
    svi_normal_result.params['auto_scale_tril'].T
)
map_values_normal = guide_normal.median(svi_normal_result.params)

# Stage 4: NUTS with the fixed dense inverse mass matrix
kernel = NUTS(
    model,
    init_strategy=init_to_value(values=map_values_normal),
    dense_mass=True,
    inverse_mass_matrix=svi_normal_cov,
    adapt_mass_matrix=False,
    target_accept_prob=0.8,
)
mcmc = MCMC(
    kernel,
    num_warmup=1000,
    num_samples=2000,
    num_chains=6,
    progress_bar=True,
)
rng_key, rng_key_ = jax.random.split(rng_key)
mcmc.run(rng_key_, data)
```

I also tried using `NeuTraReparam` with the learned MVN guide, but this resulted in slower HMC steps and worse r_hat values at the other end.