Hi all,
I would like to render either of the models below with `numpyro.render_model`, but the resulting `ValueError`
is a little cryptic to me. I first thought the error was related to my model definitions, but both attempts fail with the same `PyTreeDef` error,
differing only in the number of expected leaves.
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
# Probability handed to dist.Bernoulli for the "murderer" site in the models
# below, i.e. P(murderer == 1).
guess: float = 0.7
def mystery_extend(guess):
    """Murder-mystery model with one explicit sample site per piece of evidence.

    ``murderer`` is drawn from ``Bernoulli(guess)``. Each piece of evidence
    then gets its own Categorical distribution: the row of its conditional
    probability table (CPT) selected by the sampled ``murderer``.

    Args:
        guess: probability passed to ``dist.Bernoulli`` for the ``murderer``
            site, i.e. P(murderer == 1).

    Returns:
        Tuple ``(murderer, weapon, hair)`` of the sampled site values.

    NOTE(review): every latent site in this model is discrete. The
    ``numpyro.render_model`` in the reported traceback computes
    ``jax.jacobian`` only over non-discrete latent sites. With none present
    here, that appears to be why it raises
    ``ValueError: Too few leaves for PyTreeDef``. A newer numpyro release
    may handle discrete-only models; TODO confirm.
    """
    # Row r of each CPT is the distribution of that evidence given murderer == r.
    weapon_cpt = jnp.array([[0.9, 0.1], [0.2, 0.8]])
    hair_cpt = jnp.array([[0.5, 0.5], [0.95, 0.05]])

    murderer = numpyro.sample("murderer", dist.Bernoulli(probs=guess))
    # Indexing each CPT with the sampled value conditions the evidence on it.
    weapon = numpyro.sample("weapon", dist.Categorical(probs=weapon_cpt[murderer]))
    hair = numpyro.sample("hair", dist.Categorical(probs=hair_cpt[murderer]))
    return murderer, weapon, hair
def mystery_plate(guess):
    """Vectorized variant of ``mystery_extend`` with all evidence under one plate.

    The weapon and hair CPTs are stacked into ``evidence`` with shape
    ``(num_evidence, 2, 2)``, where ``evidence[i, r]`` is the distribution
    of evidence item ``i`` given ``murderer == r``. A single plated
    ``"evidence"`` site then samples every item at once.

    Args:
        guess: probability passed to ``dist.Bernoulli`` for the ``murderer``
            site, i.e. P(murderer == 1).

    Returns:
        Tuple ``(murderer, obs)``. ``obs`` holds one sampled value per
        evidence item.

    NOTE(review): every latent site in this model is discrete. The
    ``numpyro.render_model`` in the reported traceback computes
    ``jax.jacobian`` only over non-discrete latent sites. With none present
    here, that appears to be why it raises
    ``ValueError: Too few leaves for PyTreeDef``. A newer numpyro release
    may handle discrete-only models; TODO confirm.
    """
    murderer = numpyro.sample("murderer", dist.Bernoulli(probs=guess))

    weapon_cpt = jnp.array([[0.9, 0.1], [0.2, 0.8]])
    hair_cpt = jnp.array([[0.5, 0.5], [0.95, 0.05]])
    # Axis 0 = evidence item, axis 1 = murderer value, axis 2 = outcome.
    evidence = jnp.stack([weapon_cpt, hair_cpt])
    size = len(evidence)

    with numpyro.plate(f'i=1..{size}', size=size):
        # Take the murderer's row from *every* CPT: shape (size, 2), one
        # Categorical per evidence item. The previous evidence[murderer]
        # indexed the evidence axis instead, selecting a whole CPT (weapon if
        # murderer == 0, hair if murderer == 1) and mixing up the two axes.
        obs = numpyro.sample(
            "evidence", dist.Categorical(probs=evidence[:, murderer])
        )
    return murderer, obs
Using `mystery_extend` in `numpyro.render_model`:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
/tmp/ipykernel_78015/3475709414.py in <module>
21 return murderer, obs
22
---> 23 numpyro.render_model(mystery_extend, (guess,), render_distributions=True)
~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/contrib/render.py in render_model(model, model_args, model_kwargs, filename, render_distributions, num_tries)
310 :param int num_tries: Times to trace model to detect discrete -> continuous dependency.
311 """
--> 312 relations = get_model_relations(
313 model, model_args=model_args, model_kwargs=model_kwargs, num_tries=num_tries
314 )
~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/contrib/render.py in get_model_relations(model, model_args, model_kwargs, num_tries)
118 and not site["fn"].is_discrete
119 }
--> 120 log_prob_grads = jax.jacobian(get_log_probs)(samples)
121 sample_deps = {}
122 for name, grads in log_prob_grads.items():
~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/jax/_src/api.py in jacfun(*args, **kwargs)
1067 example_args = dyn_args[0] if isinstance(argnums, int) else dyn_args
1068 jac = tree_map(partial(_unravel_array_into_pytree, y, 0), jac)
-> 1069 return tree_transpose(tree_structure(example_args), tree_structure(y), jac)
1070
1071 return jacfun
~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/jax/_src/tree_util.py in tree_transpose(outer_treedef, inner_treedef, pytree_to_transpose)
189 transposed_lol = zip(*lol)
190 subtrees = map(partial(tree_unflatten, outer_treedef), transposed_lol)
--> 191 return tree_unflatten(inner_treedef, subtrees)
192
193 # TODO(mattjj): remove the Python-side registry when the C++-side registry is
~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/jax/_src/tree_util.py in tree_unflatten(treedef, leaves)
59 described by ``treedef``.
60 """
---> 61 return treedef.unflatten(leaves)
62
63 def tree_leaves(tree):
ValueError: Too few leaves for PyTreeDef; expected 3, got 0
Using `mystery_plate` in `numpyro.render_model`:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
/tmp/ipykernel_78015/2897400285.py in <module>
21 return murderer, obs
22
---> 23 numpyro.render_model(mystery_plate, (guess,), render_distributions=True)
~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/contrib/render.py in render_model(model, model_args, model_kwargs, filename, render_distributions, num_tries)
310 :param int num_tries: Times to trace model to detect discrete -> continuous dependency.
311 """
--> 312 relations = get_model_relations(
313 model, model_args=model_args, model_kwargs=model_kwargs, num_tries=num_tries
314 )
~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/contrib/render.py in get_model_relations(model, model_args, model_kwargs, num_tries)
118 and not site["fn"].is_discrete
119 }
--> 120 log_prob_grads = jax.jacobian(get_log_probs)(samples)
121 sample_deps = {}
122 for name, grads in log_prob_grads.items():
~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/jax/_src/api.py in jacfun(*args, **kwargs)
1067 example_args = dyn_args[0] if isinstance(argnums, int) else dyn_args
1068 jac = tree_map(partial(_unravel_array_into_pytree, y, 0), jac)
-> 1069 return tree_transpose(tree_structure(example_args), tree_structure(y), jac)
1070
1071 return jacfun
~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/jax/_src/tree_util.py in tree_transpose(outer_treedef, inner_treedef, pytree_to_transpose)
189 transposed_lol = zip(*lol)
190 subtrees = map(partial(tree_unflatten, outer_treedef), transposed_lol)
--> 191 return tree_unflatten(inner_treedef, subtrees)
192
193 # TODO(mattjj): remove the Python-side registry when the C++-side registry is
~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/jax/_src/tree_util.py in tree_unflatten(treedef, leaves)
59 described by ``treedef``.
60 """
---> 61 return treedef.unflatten(leaves)
62
63 def tree_leaves(tree):
ValueError: Too few leaves for PyTreeDef; expected 2, got 0