I am trying to figure out the order in which the `AutoDiagonalNormal` guide flattens its latent variables into the `auto_loc` and `auto_scale` arrays. Is there an easy way to do this? I initially assumed it would match `jax.flatten_util.ravel_pytree`, but this does not seem to be the case.
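A small pure-JAX check suggests why `ravel_pytree` is not a reliable reference here: JAX flattens a dict by its sorted keys, not by insertion order. The values below are made up; only the site names `y` and `x` are taken from the toy model.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Illustrative values for two sites, inserted in the order a model would
# sample them: "y" first, then "x".
latents = {"y": jnp.array(0.5), "x": jnp.array([2.0])}

flat, unravel = ravel_pytree(latents)

# JAX flattens dicts by sorted key, so "x" lands at position 0 and "y" at
# position 1, regardless of the order the keys were inserted in.
print(flat.tolist())  # [2.0, 0.5]

# unravel uses the same sorted ordering, so ravel/unravel round-trips cleanly.
restored = unravel(flat)
print(float(restored["y"]), restored["x"].tolist())  # 0.5 [2.0]
```

So any flat vector built in model order (`y` then `x`) will disagree with `ravel_pytree` whenever the site names are not already in alphabetical order.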
A toy example:

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree
import numpyro
import numpyro.distributions as dist
from numpyro.infer.autoguide import AutoDiagonalNormal

def model_funnel():
    y = numpyro.sample("y", dist.Normal(0, 3))
    numpyro.sample("x", dist.Normal(jnp.zeros(1), jnp.exp(y / 2)))

guide_funnel = AutoDiagonalNormal(model_funnel)
optim = numpyro.optim.Adam(step_size=1e-4)
svi = numpyro.infer.SVI(model_funnel, guide_funnel, optim, loss=numpyro.infer.Trace_ELBO())
svi_result = svi.run(jax.random.PRNGKey(0), 2000)

auto_loc = svi_result.params['auto_loc']
unpacked_loc = guide_funnel.median(params=svi_result.params)
# Fails: the two flat vectors contain the same values in a different order
jnp.allclose(auto_loc, ravel_pytree(unpacked_loc)[0])
```
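From this example I can see that `auto_loc` stores `x` and `y` in the opposite order to the one produced by `ravel_pytree`: `auto_loc` has `y` first and then `x` (the order in which the model samples them), while `ravel_pytree` puts `x` first.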
The same mismatch shows up in the guide's `_unpack_latent` transform: `_unpack_latent._inverse` does not undo `_unpack_latent`.
```python
values = jnp.array([1.0, 2.0])
unpack = guide_funnel._unpack_latent(values)         # dict of per-site arrays
pack = guide_funnel._unpack_latent._inverse(unpack)  # should recover `values`
jnp.allclose(pack, values)
```
In both cases `allclose` returns `False`. Is there a clean way to figure out the order?