I want to be able to call `render_model` on a model where a `param` is initialized randomly (see the minimal example below). However, I'm currently getting an error.
Is there anything I might be overlooking? If not, I’m happy to take a shot at fixing it.
from numpyro import render_model, param, sample
from numpyro.distributions import Exponential
from numpyro.handlers import seed
def model_works():
    """Sample ``x`` from an Exponential whose rate is the ``rate`` param.

    The param is initialized with the fixed value ``1.``, so no PRNG key is
    needed to create it.
    """
    return sample('x', Exponential(param('rate', 1.)))


def model_fail():
    """Sample ``x`` from an Exponential whose rate is a randomly initialized param.

    Unlike ``model_works``, the ``rate`` param is given a callable init value.
    That callable receives a PRNG key and draws the initial rate from
    ``Exponential(1.)``. This is the model that ``render_model`` fails on.
    """
    def _draw_initial_rate(rng_key):
        # Init function for the `rate` param: one draw from Exponential(1).
        return Exponential(1.).sample(rng_key)

    rate = param('rate', _draw_initial_rate)
    return sample('x', Exponential(rate))


# Sanity-check the forward pass of both models under a seeded PRNG. Each
# model must return a scalar draw for `x`.
with seed(rng_seed=0):
    for forward in (model_works, model_fail):
        assert forward().shape == ()

# Render each model's graph, including distributions and params. The first
# call succeeds. The second (model_fail) fails at
# infer.inspect.get_model_relations line 326.
for model, filename in (
    (model_works, "model_works.png"),
    (model_fail, "model_fail.png"),
):
    render_model(
        model=model,
        model_args=(),
        render_distributions=True,
        render_params=True,
        filename=filename,
    )