For some models, SVI needs a proper random initialization, as Pyro’s docs also explain for Gaussian Mixture Models. Since JAX, and thus NumPyro, handles the random state explicitly, doing a random initialization in the guide function is not trivial: the guide has to take the same arguments as the model, so a `PRNGKey` cannot simply be added to its signature. Several ways appear to be possible to accomplish random initialization:
- closure: The actual guide function could be nested within an outer function that takes a random state and returns the nested guide, such that the guide function can access a `PRNGKey` object from the closure to do the initialization.
- numpy: Using numpy and its global random state to initialize arrays and converting them to JAX.
Those are the two obvious ways that come to my mind. Is there a recommended way in the NumPyro community?