`numpyro.render_model`: ValueError: Too few leaves for PyTreeDef

Hi all,

I would like to render either of the models below, but the resulting ValueError is a little cryptic to me. I first suspected my model definition, but both attempts raise the same PyTreeDef error, differing only in the number of expected leaves.

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

# Probability passed to ``dist.Bernoulli`` for the "murderer" site; supplied as
# the sole model argument when calling ``numpyro.render_model``.
guess = 0.7

def mystery_extend(guess):
    """Murder-mystery model with one explicitly named sample site per clue.

    Args:
        guess: probability passed to the ``Bernoulli`` prior over ``murderer``.

    Returns:
        Tuple ``(murderer, weapon, hair)`` of the sampled values.
    """
    # Conditional probability tables: row ``m`` is the outcome distribution of
    # the clue given ``murderer == m``. Insertion order fixes the sampling order.
    clue_tables = {
        "weapon": jnp.array([[0.9, 0.1], [0.2, 0.8]]),
        "hair": jnp.array([[0.5, 0.5], [0.95, 0.05]]),
    }
    culprit = numpyro.sample("murderer", dist.Bernoulli(guess))

    drawn = [culprit]
    for site_name, table in clue_tables.items():
        drawn.append(numpyro.sample(site_name, dist.Categorical(table[culprit])))
    return tuple(drawn)

def mystery_plate(guess):
    """Murder-mystery model with the evidence sites vectorized under a plate.

    Plated counterpart of ``mystery_extend``: every evidence item (weapon,
    hair) is drawn from its own conditional probability table, with the row
    chosen by the sampled ``murderer``.

    Args:
        guess: probability passed to the ``Bernoulli`` prior over ``murderer``.

    Returns:
        Tuple ``(murderer, obs)``. ``obs`` holds one sampled outcome per
        evidence item, in the order (weapon, hair).
    """
    murderer = numpyro.sample("murderer", dist.Bernoulli(guess))
    weapon_cpt = jnp.array([[0.9, 0.1], [0.2, 0.8]])
    hair_cpt = jnp.array([[0.5, 0.5], [0.95, 0.05]])
    # Stack along axis 1 so the murderer is the leading axis:
    # evidence[m, e] is the outcome distribution of evidence item ``e`` given
    # murderer == m. Indexing with ``murderer`` then yields one probability row
    # per evidence item, i.e. the same rows as ``weapon_cpt[murderer]`` and
    # ``hair_cpt[murderer]`` in ``mystery_extend``. Stacking along axis 0
    # would instead make ``murderer`` select an entire table.
    evidence = jnp.stack([weapon_cpt, hair_cpt], axis=1)
    size = evidence.shape[1]  # number of evidence items (plate size)
    with numpyro.plate(f'i=1..{size}', size=size):
        obs = numpyro.sample("evidence", dist.Categorical(evidence[murderer]))
    return murderer, obs

Using `mystery_extend` with `numpyro.render_model`:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
/tmp/ipykernel_78015/3475709414.py in <module>
     21     return murderer, obs
     22 
---> 23 numpyro.render_model(mystery_extend, (guess,), render_distributions=True)

~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/contrib/render.py in render_model(model, model_args, model_kwargs, filename, render_distributions, num_tries)
    310     :param int num_tries: Times to trace model to detect discrete -> continuous dependency.
    311     """
--> 312     relations = get_model_relations(
    313         model, model_args=model_args, model_kwargs=model_kwargs, num_tries=num_tries
    314     )

~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/contrib/render.py in get_model_relations(model, model_args, model_kwargs, num_tries)
    118         and not site["fn"].is_discrete
    119     }
--> 120     log_prob_grads = jax.jacobian(get_log_probs)(samples)
    121     sample_deps = {}
    122     for name, grads in log_prob_grads.items():

~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/jax/_src/api.py in jacfun(*args, **kwargs)
   1067     example_args = dyn_args[0] if isinstance(argnums, int) else dyn_args
   1068     jac = tree_map(partial(_unravel_array_into_pytree, y, 0), jac)
-> 1069     return tree_transpose(tree_structure(example_args), tree_structure(y), jac)
   1070 
   1071   return jacfun

~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/jax/_src/tree_util.py in tree_transpose(outer_treedef, inner_treedef, pytree_to_transpose)
    189   transposed_lol = zip(*lol)
    190   subtrees = map(partial(tree_unflatten, outer_treedef), transposed_lol)
--> 191   return tree_unflatten(inner_treedef, subtrees)
    192 
    193 # TODO(mattjj): remove the Python-side registry when the C++-side registry is

~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/jax/_src/tree_util.py in tree_unflatten(treedef, leaves)
     59     described by ``treedef``.
     60   """
---> 61   return treedef.unflatten(leaves)
     62 
     63 def tree_leaves(tree):

ValueError: Too few leaves for PyTreeDef; expected 3, got 0

Using `mystery_plate` with `numpyro.render_model`:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
/tmp/ipykernel_78015/2897400285.py in <module>
     21     return murderer, obs
     22 
---> 23 numpyro.render_model(mystery_plate, (guess,), render_distributions=True)

~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/contrib/render.py in render_model(model, model_args, model_kwargs, filename, render_distributions, num_tries)
    310     :param int num_tries: Times to trace model to detect discrete -> continuous dependency.
    311     """
--> 312     relations = get_model_relations(
    313         model, model_args=model_args, model_kwargs=model_kwargs, num_tries=num_tries
    314     )

~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/contrib/render.py in get_model_relations(model, model_args, model_kwargs, num_tries)
    118         and not site["fn"].is_discrete
    119     }
--> 120     log_prob_grads = jax.jacobian(get_log_probs)(samples)
    121     sample_deps = {}
    122     for name, grads in log_prob_grads.items():

~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/jax/_src/api.py in jacfun(*args, **kwargs)
   1067     example_args = dyn_args[0] if isinstance(argnums, int) else dyn_args
   1068     jac = tree_map(partial(_unravel_array_into_pytree, y, 0), jac)
-> 1069     return tree_transpose(tree_structure(example_args), tree_structure(y), jac)
   1070 
   1071   return jacfun

~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/jax/_src/tree_util.py in tree_transpose(outer_treedef, inner_treedef, pytree_to_transpose)
    189   transposed_lol = zip(*lol)
    190   subtrees = map(partial(tree_unflatten, outer_treedef), transposed_lol)
--> 191   return tree_unflatten(inner_treedef, subtrees)
    192 
    193 # TODO(mattjj): remove the Python-side registry when the C++-side registry is

~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/jax/_src/tree_util.py in tree_unflatten(treedef, leaves)
     59     described by ``treedef``.
     60   """
---> 61   return treedef.unflatten(leaves)
     62 
     63 def tree_leaves(tree):

ValueError: Too few leaves for PyTreeDef; expected 2, got 0

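I think you've hit an edge case. Both models contain only discrete sites, so `samples` is an empty dict: it holds only the non-discrete sites, per the `not site["fn"].is_discrete` filter just before the `jax.jacobian` call. `jax.jacobian` is not smart enough to take a gradient with respect to an empty input. To resolve the issue, you can try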

-    log_prob_grads = jax.jacobian(get_log_probs)(samples)
+    if samples:
+        log_prob_grads = jax.jacobian(get_log_probs)(samples)
+    else:
+        log_prob_grads = {k: {} for k in get_log_probs(samples)}

Do you want to submit the fix?

Cool, thank you @fehiepsi! I could, but I don’t know how to submit the fix… :sweat_smile:

Should I open an issue or a PR?

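Yes, an issue is helpful (it reminds us to address the problem). Alternatively, you can submit a PR that applies the above fix to the `jax.jacobian` line in `get_model_relations` (`numpyro/contrib/render.py`). The PR should also add a regression test, using the models above, to the corresponding test file.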