Numpyro render_model failure?

Hello,
Let me first show a working snippet before showing an example that fails:

import jax.numpy as jnp
import numpyro  # '0.9.1'
import numpyro.distributions as dist

def diag_mean_fn(x, params):
    """Evaluate the mean curve R0 + v*x - k*(1 - exp(-x/tau)) at ``x``.

    ``params`` is a mapping that provides the entries "R0", "v", "k"
    and "tau" used in the formula above.
    """
    offset, slope, amplitude, timescale = (
        params[key] for key in ("R0", "v", "k", "tau")
    )
    linear_part = offset + slope * x
    saturating_part = amplitude * (1.0 - jnp.exp(-x / timescale))
    return linear_part - saturating_part

def diag_model(t, Robs,
               R0_min=10., v_min=0.5, k_min=1., tau_min=0.1,
               R0_max=50., v_max=3.5, k_max=10., tau_max=5.0,
               sigma=1.0):
    """NumPyro model: uniform priors on (R0, v, k, tau), Normal likelihood.

    Parameters
    ----------
    t : array whose ``shape[0]`` sets the size of the "obs" plate; it is
        passed to ``diag_mean_fn`` as the abscissa.
    Robs : observed values attached to the sample site 'R'.
    R0_min, R0_max, v_min, v_max, k_min, k_max, tau_min, tau_max :
        bounds of the Uniform prior of the corresponding parameter.
    sigma : scale of the Normal likelihood.

    Note: ``dist`` must refer to ``numpyro.distributions``
    (``import numpyro.distributions as dist``); without that import the
    first ``numpyro.sample`` call below raises ``NameError``.
    """
    # priors
    R0 = numpyro.sample("R0", dist.Uniform(R0_min, R0_max))
    v = numpyro.sample("v", dist.Uniform(v_min, v_max))
    k = numpyro.sample("k", dist.Uniform(k_min, k_max))
    tau = numpyro.sample("tau", dist.Uniform(tau_min, tau_max))

    # Mean of the likelihood, evaluated at t with the sampled parameters.
    mu = diag_mean_fn(t, {"R0": R0, "v": v, "k": k, "tau": tau})

    # One observation per entry along the first axis of t.
    with numpyro.plate("obs", t.shape[0]):
        numpyro.sample('R', dist.Normal(mu, sigma), obs=Robs)

# Draw the model graph; one-element arrays stand in for t and Robs.
numpyro.render_model(
    diag_model,
    model_args=(jnp.array([0.0]), jnp.array([1.0])),
)

This nicely produces the following schema:
image

Now, change `diag_mean_fn` to:

def diag_mean_fn(x, params):  # piecewise variant of the mean, selected on the sign of x
    R0 = params["R0"]
    v  = params["v"]
    k  = params["k"]
    tau =  params["tau"]
    return jnp.piecewise(
        x, [x < 0, x >= 0],
        [lambda x: R0 + v*x,  # used where x < 0: linear part only
         lambda x: R0 + v*x - k*(1.-jnp.exp(-x/tau))  # used where x >= 0: linear part minus k*(1 - exp(-x/tau))
        ])

This leads to the following crash:
…
File /jax/_src/numpy/lax_numpy.py:4345, in _piecewise.<locals>._call.<locals>.<lambda>(x)
   4344 def _call(f):
-> 4345   return lambda x: f(x, *args, **kw).astype(dtype)

Input In [112], in diag_mean_fn.<locals>.<lambda>(x)
      4 k = params["k"]
      5 tau = params["tau"]
      6 return jnp.piecewise(
      7     x, [x < 0, x >= 0],
----> 8     [lambda x: R0 + v*x,
      9      lambda x: R0 + v*x - k*(1.-jnp.exp(-x/tau))
     10     ])

[... skipping hidden 1 frame]

File /jax/_src/numpy/lax_numpy.py:4502, in _defer_to_unrecognized_arg.<locals>.deferring_binary_op(self, other)
   4500 if not isinstance(other, _accepted_binop_types):
   4501   return NotImplemented
-> 4502 return binary_op(self, other)

[... skipping hidden 7 frame]

File /jax/_src/numpy/ufuncs.py:87, in _maybe_bool_binop.<locals>.fn(x1, x2)
     85 def fn(x1, x2):
     86   x1, x2 = promote_args(numpy_fn.__name__, x1, x2)
---> 87   return lax_fn(x1, x2) if x1.dtype != np.bool_ else bool_lax_fn(x1, x2)

[... skipping hidden 7 frame]

File /jax/core.py:1693, in join_named_shapes(*named_shapes)
   1691 for name, size in named_shape.items():
   1692   if result.setdefault(name, size) != size:
-> 1693     raise TypeError(
   1694         f"Axis name {name} used with inconsistent sizes: {result[name]} != {size}")
   1695 return result

TypeError: Axis name _provenance used with inconsistent sizes: frozenset({'R0'}) != frozenset({'v'})


Do you have an idea how I can solve this problem? Thanks.

Side question: does someone have a nice link that explains simply what a `plate` is in the context of statistical modelling?

I think the current provenance approach does not work with jax control flow. I think you can use

return jnp.where(x < 0, R0 + v*x, R0 + v*x - k*(1.-jnp.exp(-x/tau)))

instead. Here is a reference for plate: Plate notation - Wikipedia :slight_smile:

Oh. Nice indeed.