Sampling from a distribution without an RNG key

With Pyro, I can do:

import pyro
pyro.distributions.Normal(0.0, 1.0).sample()

However, with NumPyro:

import numpyro
numpyro.distributions.Normal(0.0, 1.0).sample()
TypeError: sample() missing 1 required positional argument: 'key'

According to the NumPyro documentation, key is optional. Is there any workaround besides explicitly passing a key?

It is required. We don't support a global random state; arguably it would be convenient, but it would also be easy to misuse. I think you can also sample inside a seed handler:

import numpyro
import numpyro.distributions as dist
with numpyro.handlers.seed(rng_seed=0):
    a = numpyro.sample("a", dist.Normal(0, 1))