With Pyro, I can draw a sample without passing any random state:
```python
import pyro

pyro.distributions.Normal(0.0, 1.0).sample()
```
However, the same call in NumPyro

```python
import numpyro

numpyro.distributions.Normal(0.0, 1.0).sample()
```

fails with:

```
TypeError: sample() missing 1 required positional argument: 'key'
```
According to the NumPyro documentation, `key` is optional. Is there any workaround besides passing a key explicitly?