Hello,
Let me write a working snippet before showing an example with a failure. So
import jax.numpy as jnp
import numpyro # '0.9.1'
def diag_mean_fn(x, params):
R0 = params["R0"]
v = params["v"]
k = params["k"]
tau = params["tau"]
return R0 + v*x - k*(1.-jnp.exp(-x/tau))
def diag_model(t,Robs,
R0_min=10.,v_min = 0.5,k_min = 1., tau_min=0.1,
R0_max=50.,v_max = 3.5,k_max = 10., tau_max=5.0,
sigma= 1.0):
# priors
R0 = numpyro.sample("R0", dist.Uniform(R0_min,R0_max))
v = numpyro.sample("v", dist.Uniform(v_min,v_max))
k = numpyro.sample("k", dist.Uniform(k_min,k_max))
tau= numpyro.sample("tau", dist.Uniform(tau_min,tau_max))
mu = diag_mean_fn(t,{"R0":R0, "v":v, "k":k, "tau":tau})
with numpyro.plate("obs", t.shape[0]):
numpyro.sample('R', dist.Normal(mu, sigma), obs=Robs)
numpyro.render_model(diag_model, model_args=(jnp.array([0.]),jnp.array([1.])))
this gives nicely this schema
Now, change the diag_mean_fn
to
def diag_mean_fn(x, params):
R0 = params["R0"]
v = params["v"]
k = params["k"]
tau = params["tau"]
return jnp.piecewise(
x, [x < 0, x >= 0],
[lambda x: R0 + v*x,
lambda x: R0 + v*x - k*(1.-jnp.exp(-x/tau))
])
this leads to a crash of
β¦
File /jax/_src/numpy/lax_numpy.py:4345, in _piecewise.._call..(x)
4344 def _call(f):
β 4345 return lambda x: f(x, *args, **kw).astype(dtype)
Input In [112], in diag_mean_fn..(x)
4 k = params[βkβ]
5 tau = params[βtauβ]
6 return jnp.piecewise(
7 x, [x < 0, x >= 0],
----> 8 [lambda x: R0 + vx,
9 lambda x: R0 + vx - k*(1.-jnp.exp(-x/tau))
10 ])
[... skipping hidden 1 frame]
File /jax/_src/numpy/lax_numpy.py:4502, in _defer_to_unrecognized_arg..deferring_binary_op(self, other)
4500 if not isinstance(other, _accepted_binop_types):
4501 return NotImplemented
β 4502 return binary_op(self, other)
[... skipping hidden 7 frame]
File /jax/_src/numpy/ufuncs.py:87, in _maybe_bool_binop..fn(x1, x2)
85 def fn(x1, x2):
86 x1, x2 = promote_args(numpy_fn.name, x1, x2)
β> 87 return lax_fn(x1, x2) if x1.dtype != np.bool else bool_lax_fn(x1, x2)
[... skipping hidden 7 frame]
File /jax/core.py:1693, in join_named_shapes(*named_shapes)
1691 for name, size in named_shape.items():
1692 if result.setdefault(name, size) != size:
β 1693 raise TypeError(
1694 f"Axis name {name} used with inconsistent sizes: {result[name]} != {size}")
1695 return result
TypeError: Axis name _provenance used with inconsistent sizes: frozenset({βR0β}) != frozenset({βvβ})
Have you an idea how I can solve this problem? Thanks.
Side question does someone has a nice link to explain simply what is a `plate` in the context of statistical modelling?