Obtaining the unconstrained MAP parameters from the Laplace guide

I’m fitting the Laplace approximation to a model and I would like to find the unconstrained MAP parameters given a dictionary containing the constrained MAP parameters. One way to do this is the following:

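```python
from jax import random
import numpyro
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoLaplaceApproximation
from numpyro.infer.initialization import init_to_value

# `model`, `model_kwargs` and `map_params_constrained` are defined elsewhere.
guide = AutoLaplaceApproximation(model, init_loc_fn=init_to_value(values=map_params_constrained))
svi = SVI(model, guide, numpyro.optim.Minimize(method='BFGS'), Trace_ELBO())
state = svi.init(random.PRNGKey(0), **model_kwargs)
map_params_unconstrained = svi.get_params(state)
```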

`svi.get_params(state)` returns a dictionary whose key `auto_loc` holds a single flat array with the unconstrained values of all latent sites. The problem with this approach is that if I `jax.jit` this function, the order of the parameters in `map_params_unconstrained['auto_loc']` gets permuted for some reason!

Is there a way to get a dictionary of the unconstrained parameters instead of an array?
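A flat vector like `auto_loc` is only meaningful together with the key order used to build it. The minimal pure-JAX sketch below shows one way such an order can shift under `jit`. It is an illustration, not a diagnosis of the behaviour reported above. JAX rebuilds dictionaries with sorted keys when they pass through a transformation, so a packer that concatenates values in dictionary order sees a different order afterwards. The site names, the values, and the `pack_in_dict_order` helper are hypothetical. `jax.flatten_util.ravel_pytree` is a real JAX utility that returns the flat array together with a function that unpacks it by key.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def pack_in_dict_order(params):
    # Hypothetical packer: concatenates every site's values, in the dict's
    # current key order, into one flat vector.
    return jnp.concatenate([jnp.ravel(v) for v in params.values()])


# Hypothetical unconstrained values, inserted as "scale" first, then "loc".
params = {"scale": jnp.array([0.7]), "loc": jnp.array([0.5, -1.0])}

# jit flattens and rebuilds the dict, and JAX orders dict keys by sorting them.
roundtrip = jax.jit(lambda p: p)(params)

eager_flat = pack_in_dict_order(params)       # [0.7, 0.5, -1.0]
jitted_flat = pack_in_dict_order(roundtrip)   # [0.5, -1.0, 0.7]

# ravel_pytree returns the flat vector together with its own unravel function,
# so values are recovered by key instead of by position.
flat, unravel = ravel_pytree(params)
recovered = unravel(flat)
```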

I guess you can try

```python
guide._unpack_latent(map_params_unconstrained['auto_loc'])
```

(it would be helpful if you could provide some reproducible code so that we can see what you mean by “permuted”)


That works both within and outside jax.jit, thank you so much!

I may post a reproducible example later if I have time. In general, I have found that in rare cases `jit` changes the behaviour of JAX functions. It may be highly dependent on the hardware and on the versions of `jax` and `jaxlib`.