AutoMultivariateNormal guide and funsor

Hi there,

I’m trying to use an AutoMultivariateNormal guide for an SVI fit of a model with discrete latent variables in NumPyro. AutoNormal works fine, but when I switch to AutoMultivariateNormal, the SVI fit throws an error. I would like to use AutoMultivariateNormal because AutoNormal underestimates the parameter variances in my model compared to NUTS. Is this a bug in numpyro.contrib.funsor, or have I messed something up?
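For context on the variance underestimation, here is a small standalone illustration (plain Python, not tied to the model below). It assumes a toy target: a bivariate Gaussian with unit variances and correlation `rho`, where `rho = 0.9` is an arbitrary choice. For a Gaussian target, the fully factorized Gaussian that minimizes KL(q || p), which is what maximizing the ELBO does, has variances equal to the reciprocals of the diagonal of the precision matrix, and these are smaller than the true marginal variances whenever the parameters are correlated.

```python
# Toy target (an assumption for illustration): zero-mean bivariate Gaussian
# with unit variances and correlation rho.
rho = 0.9
cov = [[1.0, rho], [rho, 1.0]]

# Invert the 2x2 covariance by hand to get the precision matrix.
det = cov[0][0] * cov[1][1] - cov[0][1] * cov[1][0]
precision = [
    [cov[1][1] / det, -cov[0][1] / det],
    [-cov[1][0] / det, cov[0][0] / det],
]

# Optimal mean-field (diagonal) Gaussian under KL(q || p):
# each variance is 1 / precision[i][i].
mean_field_var = [1.0 / precision[i][i] for i in range(2)]

# True marginal variances are the diagonal of the covariance.
true_marginal_var = [cov[i][i] for i in range(2)]

print("mean-field variances:   ", mean_field_var)
print("true marginal variances:", true_marginal_var)
```

With these numbers the mean-field variance comes out at roughly 0.19 against a true marginal variance of 1.0, which is the kind of gap a full-covariance guide is meant to close.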

Minimal example; model definition:

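import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.funsor import config_enumerate

def model():
  K = 5   # number of categories
  N = 10  # number of discrete assignments

  # Continuous latent, one value per category
  with numpyro.plate("alpha_plate", K):
    alpha = numpyro.sample("alpha", dist.Uniform(0.3, 10.0))

  # Category probabilities
  prob = numpyro.sample("prob", dist.Dirichlet(jnp.ones(K) * 0.5))

  # Discrete latent, to be enumerated out
  with numpyro.plate("prob_plate", N):
    assignment = numpyro.sample(
        "assignment",
        dist.CategoricalProbs(prob)
    )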

Working SVI with AutoNormal

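rng = jax.random.PRNGKey(0)

# Build the guide over the continuous latents only: "assignment" is hidden
# from the autoguide because it is enumerated in the model.
guide_N = numpyro.infer.autoguide.AutoNormal(
    numpyro.handlers.block(
        numpyro.handlers.seed(model, rng),
        hide=["assignment"]
    )
)

svi = numpyro.infer.SVI(
    config_enumerate(model),  # mark discrete sites for enumeration
    guide_N,
    numpyro.optim.Adam(0.01),
    numpyro.infer.TraceEnum_ELBO(num_particles=1)
)
svi_result = svi.run(rng, 5000)
svi_result.params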

Erroring code with AutoMultivariateNormal

rng = jax.random.PRNGKey(0)

# Same setup as above; only the autoguide class changes.
guide_AMN = numpyro.infer.autoguide.AutoMultivariateNormal(
    numpyro.handlers.block(
        numpyro.handlers.seed(model, rng),
        hide=["assignment"]
    )
)

svi = numpyro.infer.SVI(
    config_enumerate(model),
    guide_AMN,
    numpyro.optim.Adam(0.01),
    numpyro.infer.TraceEnum_ELBO(num_particles=1)
)
svi_result = svi.run(rng, 5000)  # raises the KeyError shown below

The error thrown is:

---------------------------------------------------------------------------

KeyError                                  Traceback (most recent call last)

<ipython-input-6-62d718ae30b0> in <cell line: 0>()
     18     elbo
     19 )
---> 20 svi_result = svi.run(rng, 5000)
     21 svi_result.params

21 frames

    [... skipping hidden 11 frame]

    [... skipping hidden 8 frame]

/usr/local/lib/python3.11/dist-packages/numpyro/contrib/funsor/enum_messenger.py in _pyro_post_to_data(self, msg)
    427             for name in msg["args"][0].inputs:
    428                 self._saved_globals += (
--> 429                     (name, _DIM_STACK.global_frame.name_to_dim[name]),
    430                 )
    431 

KeyError: 'alpha_plate'

The full code is in a Google Colab notebook, in case that helps: Google Colab

Thanks!

As far as I know, the AutoMultivariateNormal guide does not work with enumeration. You can build a custom guide, though.

Thanks, @fehiepsi. Are there any plans to make the MVN autoguide work with enumeration?