AutoDiagonalNormal flatten order

I am trying to figure out the order in which the AutoDiagonalNormal guide flattens its variables into the auto_loc and auto_scale arrays. Is there an easy way to do this? I initially thought it would match jax.flatten_util.ravel_pytree, but this does not seem to be the case.

A toy example:

```python
import jax
import jax.flatten_util
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoDiagonalNormal
from numpyro.optim import Adam

def model_funnel():
    y = numpyro.sample("y", dist.Normal(0, 3))
    numpyro.sample("x", dist.Normal(jnp.zeros(1), jnp.exp(y / 2)))

guide_funnel = AutoDiagonalNormal(model_funnel)
optim = Adam(step_size=1e-4)
svi = SVI(model_funnel, guide_funnel, optim, loss=Trace_ELBO())
svi_result = svi.run(jax.random.PRNGKey(0), 2000)

auto_loc = svi_result.params["auto_loc"]
unpacked_loc = guide_funnel.median(params=svi_result.params)

# Evaluates to False in this example: the orderings differ
jnp.allclose(auto_loc, jax.flatten_util.ravel_pytree(unpacked_loc)[0])
```

From this example, I can see that x and y appear in the opposite order from the one given by ravel_pytree.

The same issue shows up in the guide's _unpack_latent and _unpack_latent._inverse functions, which do not invert each other:

```python
values = jnp.array([1.0, 2.0])
unpack = guide_funnel._unpack_latent(values)
pack = guide_funnel._unpack_latent._inverse(unpack)
jnp.allclose(pack, values)  # also False: _inverse does not undo the unpacking
```

In both cases, allclose returns False. Is there a clean way to figure out the order?

Note: my end goal is to evaluate the guide's log_prob at various positions from an MC chain. These positions are stored in the model's structured pytree and have already been converted to unconstrained space.
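For the log_prob use case, one option that avoids guessing the ordering rule is to ask the guide's own unpacking map where each site lives: unpack the index vector 0, 1, ..., latent_dim - 1, and each site comes back holding its own flat positions. Below is a minimal pure-JAX sketch that assumes such an index layout is already in hand. `pack_with_layout` and `diag_normal_log_prob` are hypothetical helpers, not NumPyro API, and the log density assumes the guide is a diagonal Normal over the flat unconstrained vector with loc `auto_loc` and scale `auto_scale`.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm


def pack_with_layout(layout, values):
    """Hypothetical helper: pack a dict of per-site arrays into one flat vector.

    `layout` maps each site name to an integer array (same shape as the
    site's value) holding that site's positions in the flat vector.
    """
    size = sum(int(jnp.size(idx)) for idx in layout.values())
    flat = jnp.zeros(size)
    for name, idx in layout.items():
        # Scatter the site's flattened values into its flat positions.
        flat = flat.at[jnp.ravel(idx)].set(jnp.ravel(values[name]))
    return flat


def diag_normal_log_prob(flat, loc, scale):
    """Hypothetical helper: diagonal Normal log density, summed over dims."""
    return jnp.sum(norm.logpdf(flat, loc, scale))


# Layout mimicking model_funnel, where "y" is sampled before "x".
layout = {"y": jnp.array(0), "x": jnp.array([1])}
position = {"x": jnp.array([2.0]), "y": jnp.array(1.0)}
flat = pack_with_layout(layout, position)
```

With the fitted guide from the toy example, the layout could plausibly be built as `layout = guide_funnel._unpack_latent(jnp.arange(guide_funnel.latent_dim))`, and the density evaluated as `diag_normal_log_prob(pack_with_layout(layout, position), svi_result.params["auto_loc"], svi_result.params["auto_scale"])`. Both `_unpack_latent` and `latent_dim` are internals that may change between NumPyro versions. Because the layout comes from the same forward map the guide itself uses, it does not depend on knowing the ordering rule. For a whole chain with a leading draw axis, the packing can be mapped over draws with `jax.vmap`.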

This seems like a bug to me. Could you create a GitHub issue for this?

> I can see in this case x and y are in the opposite order

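IIRC, this is the order in which the variables appear in the model: model_funnel samples y first and then x, so that is their order in the trace. ravel_pytree, by contrast, orders dict entries by sorted key, which puts x first.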
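To see the two orderings side by side without fitting anything, here is a small pure-JAX sketch. `ravel_in_insertion_order` is a hypothetical helper that concatenates dict entries in insertion order; it stands in for the guide's trace-order flattening, which is an assumption based on the reply above. `ravel_pytree` flattens dict leaves in sorted key order.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def ravel_in_insertion_order(values):
    """Hypothetical helper: flatten a dict of arrays in insertion order."""
    return jnp.concatenate([jnp.ravel(v) for v in values.values()])


# Sites in the order model_funnel samples them: "y" first, then "x".
sites = {"y": jnp.array(1.0), "x": jnp.array([2.0])}

trace_order = ravel_in_insertion_order(sites)  # [1., 2.]
sorted_order, unravel = ravel_pytree(sites)  # [2., 1.]: "x" sorts before "y"

# Rebuilding the dict with alphabetical keys makes the two flattenings agree.
alphabetical = {name: sites[name] for name in sorted(sites)}
assert jnp.allclose(ravel_in_insertion_order(alphabetical), sorted_order)
```

This also suggests why packing with `ravel_pytree` cannot serve as a general inverse of a trace-order unpack: any model whose sampling order is not alphabetical comes out permuted.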