Obtaining the unconstrained MAP parameters from the Laplace guide

I’m fitting the Laplace approximation to a model and I would like to find the unconstrained MAP parameters given a dictionary containing the constrained MAP parameters. One way to do this is the following:

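from jax import random
import numpyro
from numpyro.infer import SVI, Trace_ELBO, init_to_value
from numpyro.infer.autoguide import AutoLaplaceApproximation

# Initialise the guide at the constrained MAP values.
guide = AutoLaplaceApproximation(model, init_loc_fn=init_to_value(values=map_params_constrained))
svi = SVI(model, guide, numpyro.optim.Minimize(method='BFGS'), Trace_ELBO())
state = svi.init(random.PRNGKey(0), **model_kwargs)
map_params_unconstrained = svi.get_params(state)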

svi.get_params(state) returns a dictionary with a key auto_loc containing an array with the unconstrained parameters. The problem with this approach is that if I jax.jit this function, the order of the parameters in map_params_unconstrained['auto_loc'] gets permuted for some reason!

Is there a way to get a dictionary of the unconstrained parameters instead of an array?

I guess you can try


(it would be helpful if you can provide some reproducible code, so that we can see what you mean by “permuted”)

That works both within and outside jax.jit, thank you so much!

I may post a reproducible example later if I have time. In general, I have found that in rare cases jit changes the behaviour of JAX functions. It may depend heavily on the hardware and on the versions of jax and jaxlib.
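One thing to keep in mind when reading a flat parameter array: JAX flattens plain dictionaries in sorted-key order, not insertion order. Whether this explains the permutation reported here is not established in the thread. The sketch below only shows the ordering rule, and how the unravel function returned by `jax.flatten_util.ravel_pytree` turns a flat vector back into a dictionary, including inside `jax.jit`. The parameter names and values are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Hypothetical parameter dict, deliberately built in non-alphabetical order.
params = {
    "sigma": jnp.array(2.0),
    "beta": jnp.array([0.5, -0.5]),
    "alpha": jnp.array(1.0),
}

# ravel_pytree concatenates the leaves in pytree order; for a plain dict
# that means sorted by key: alpha, beta[0], beta[1], sigma.
flat, unravel = ravel_pytree(params)

# unravel maps a flat vector back to a dict with the original keys.
restored = unravel(flat)


@jax.jit
def roundtrip(vec):
    # The same unravel function can also be called under jit.
    return unravel(vec)


restored_jit = roundtrip(flat)
```

Working with the dictionary from `unravel`, rather than indexing into the flat vector by position, avoids relying on any particular ordering.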