I’m fitting the Laplace approximation to a model and I would like to find the unconstrained MAP parameters given a dictionary containing the constrained MAP parameters. One way to do this is the following:
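import numpyro
from jax import random
from numpyro.infer import SVI, Trace_ELBO, init_to_value
from numpyro.infer.autoguide import AutoLaplaceApproximation

# Initialize the guide at the known constrained MAP values
guide = AutoLaplaceApproximation(model, init_loc_fn=init_to_value(values=map_params_constrained))
svi = SVI(model, guide, numpyro.optim.Minimize(method='BFGS'), Trace_ELBO())
state = svi.init(random.PRNGKey(0), **model_kwargs)
map_params_unconstrained = svi.get_params(state)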
svi.get_params(state) returns a dictionary with a key auto_loc containing an array with the unconstrained parameters. The problem with this approach is that if I jax.jit this function, the order of the parameters in map_params_unconstrained['auto_loc'] gets permuted for some reason!
Is there a way to get a dictionary of the unconstrained parameters instead of an array?