Recommended way of doing random initialization for SVI

For some models, a proper random initialization is necessary when using SVI, as also explained in Pyro’s docs for Gaussian Mixture Models. Since JAX, and thus NumPyro, handles the random state explicitly, doing a random initialization inside the guide function is not trivial, because the guide has to reflect the parameters of the model. There seem to be several ways to accomplish random initialization:

  1. closure: The actual guide function could be nested within an outer function that takes a random key and returns the nested guide, so that the guide can access a PRNGKey from the closure to do the initialization.
  2. numpy: Use NumPy and its global random state to initialize arrays, then convert them to JAX arrays.

These are the two obvious approaches that come to mind. What is the recommended way in the NumPyro community?


Hi @FlorianWilhelm, if you use an AutoGuide, you can get the random key from site["kwargs"]["rng_key"], as in init_to_median. A custom init_loc_fn would look similar to the one in the Gaussian mixture model example.

If you want to specify random initial values for the parameters of a custom guide, you can pass a callable that takes a PRNGKey as the initial value, something like:

numpyro.param("a", lambda rng_key: random.normal(rng_key))