Hi there,
I’m trying to use an AutoMultivariateNormal guide for an SVI fit of a model with discrete latent variables in numpyro. AutoNormal works fine, but when I use AutoMultivariateNormal the SVI fitting throws an error. I would like to use AutoMultivariateNormal because AutoNormal underestimates the parameter variances in my model compared to NUTS. Is this a bug in numpyro.contrib.funsor, or have I messed something up?
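For context on the variance issue: one common reason a fully factorized (diagonal) Gaussian guide reports variances that are too small is posterior correlation. This may or may not be what happens in my model, so the following is only a minimal sketch under an assumption: the target is an exactly Gaussian 2-D posterior with correlation rho. In that case the factorized Gaussian that minimizes KL(q || p) has variances equal to the inverse of the diagonal of the precision matrix, which works out to var * (1 - rho**2). The function name and the numbers are hypothetical and only for illustration.

```python
def mean_field_variances_2d(var_x, var_y, rho):
    """Variances of the factorized Gaussian q minimizing KL(q || p),
    assuming p is a 2-D Gaussian with marginal variances var_x, var_y
    and correlation rho (an illustrative assumption, not my actual posterior).

    For covariance [[a, c], [c, b]] with c = rho * sqrt(a * b), the precision
    diagonal is 1 / (a * (1 - rho**2)) and 1 / (b * (1 - rho**2)), and the
    optimal factorized variances are the inverses of those entries.
    """
    shrink = 1.0 - rho ** 2
    return var_x * shrink, var_y * shrink


# Hypothetical example: unit marginal variances, strong correlation.
true_var_x, true_var_y, rho = 1.0, 1.0, 0.9
mf_var_x, mf_var_y = mean_field_variances_2d(true_var_x, true_var_y, rho)

# The factorized approximation is narrower than the true marginals
# whenever rho is nonzero.
print("true marginal variances:", true_var_x, true_var_y)
print("mean-field variances:   ", mf_var_x, mf_var_y)
```

A full-covariance guide such as AutoMultivariateNormal can represent the correlation directly, which is why I would prefer it here.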
Minimal example; model definition:
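# Imports inferred from the identifiers used in the snippets below
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.funsor import config_enumerate


def model():
    K = 5   # number of categories
    N = 10  # number of discrete assignments
    with numpyro.plate("alpha_plate", K):
        alpha = numpyro.sample("alpha", dist.Uniform(0.3, 10.0))
    prob = numpyro.sample("prob", dist.Dirichlet(jnp.ones(K) * 0.5))
    with numpyro.plate("prob_plate", N):
        assignment = numpyro.sample(
            "assignment",
            dist.CategoricalProbs(prob)
        )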
Working SVI with AutoNormal:
rng = jax.random.PRNGKey(0)
guide_N = numpyro.infer.autoguide.AutoNormal(
    numpyro.handlers.block(
        numpyro.handlers.seed(
            model,
            rng
        ),
        hide=["assignment"]
    )
)
svi = numpyro.infer.SVI(
    config_enumerate(model),
    guide_N,
    numpyro.optim.Adam(0.01),
    numpyro.infer.TraceEnum_ELBO(num_particles=1)
)
svi_result = svi.run(rng, 5000)
svi_result.params
Erroring code with AutoMultivariateNormal:
rng = jax.random.PRNGKey(0)
guide_AMN = numpyro.infer.autoguide.AutoMultivariateNormal(
    numpyro.handlers.block(
        numpyro.handlers.seed(
            model,
            rng
        ),
        hide=["assignment"]
    )
)
svi = numpyro.infer.SVI(
    config_enumerate(model),
    guide_AMN,
    numpyro.optim.Adam(0.01),
    numpyro.infer.TraceEnum_ELBO(num_particles=1)
)
svi_result = svi.run(rng, 5000)
The error thrown is:
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
<ipython-input-6-62d718ae30b0> in <cell line: 0>()
18 elbo
19 )
---> 20 svi_result = svi.run(rng, 5000)
21 svi_result.params
21 frames
[... skipping hidden 11 frame]
[... skipping hidden 8 frame]
/usr/local/lib/python3.11/dist-packages/numpyro/contrib/funsor/enum_messenger.py in _pyro_post_to_data(self, msg)
427 for name in msg["args"][0].inputs:
428 self._saved_globals += (
--> 429 (name, _DIM_STACK.global_frame.name_to_dim[name]),
430 )
431
KeyError: 'alpha_plate'
The full code is in a Google Colab notebook in case that helps: Google Colab
Thanks!