I have a model that uses the Hill function from biology, which has two parameters (`ec50` and `slope`).
The code below runs without errors, but the R-hat values are astronomically large, which suggests something is wrong.
`x` has a lower bound of 0, and I think I have handled the division-by-zero issues.
@jit
def jhill(x, ec50, slope):
    """Evaluate the Hill response curve ``1 / (1 + (x / ec50) ** -slope)``.

    The curve is computed in the algebraically equivalent form
    ``sigmoid(slope * (log(x) - log(ec50)))`` and is defined as 0 where
    ``x == 0``. The result is differentiable with finite gradients with
    respect to ``ec50`` and ``slope`` even when ``x`` contains zeros.

    Args:
        x: Input value(s); expected to be >= 0.
        ec50: Input value at which the response equals 0.5.
        slope: Hill coefficient controlling how steep the transition is.

    Returns:
        Array of responses, broadcast over the shapes of the inputs.
    """
    x = jnp.asarray(x)
    is_zero = x == 0
    # "Double where": jnp.where only masks the forward value. Gradients from the
    # untaken branch are still multiplied by 0, and 0 * NaN = NaN. With the
    # original power form, x == 0 gives power(0, -slope) = inf and a NaN
    # gradient, which breaks NUTS. Substituting a harmless x before taking the
    # log keeps both branches finite.
    safe_x = jnp.where(is_zero, 1.0, x)
    z = slope * (jnp.log(safe_x) - jnp.log(ec50))
    # sigmoid(z) == 0.5 * (1 + tanh(z / 2)). tanh has bounded gradients, whereas
    # (x / ec50) ** -slope can overflow to inf for small ratios / large slopes
    # and again produce inf * 0 = NaN in the backward pass.
    response = 0.5 * (1.0 + jnp.tanh(0.5 * z))
    return jnp.where(is_zero, 0.0, response)
def model2(x=None, y=None):
    """NumPyro model: y is normally distributed around an offset Hill curve of x.

    The mean is ``a + b * jhill(x, ec50, slope)`` with observation noise
    ``sigma``.

    Priors:
        a ~ Normal(0, 10)        (offset)
        b ~ Uniform(0, 10)       (multiplier on the Hill curve)
        slope ~ Uniform(0.5, 7)  (Hill coefficient)
        ec50 ~ Gamma(1, 1)       (half-maximal input)
        sigma ~ Gamma(1, 1)      (observation noise scale)

    Args:
        x: Predictor values passed to ``jhill``.
        y: Observed responses; when None, 'obs' is sampled rather than
            conditioned on.
    """
    # Sites are declared in the same order as before; NumPyro splits RNG keys
    # per sample statement, so the order is part of the model's behavior.
    offset = numpyro.sample('a', dist.Normal(0, 10))
    scale = numpyro.sample('b', dist.Uniform(0, 10))
    hill_coef = numpyro.sample('slope', dist.Uniform(.5, 7))
    half_max = numpyro.sample('ec50', dist.Gamma(1, 1))
    noise_sd = numpyro.sample('sigma', dist.Gamma(1, 1))

    mean = offset + scale * jhill(x, half_max, hill_coef)
    likelihood = dist.Normal(mean, noise_sd)
    numpyro.sample('obs', likelihood, obs=y)
# Fixed PRNG seed so the MCMC run is reproducible.
rng_key = random.PRNGKey(0)

# NUTS sampler over model2, with 1000 warmup steps per chain.
nuts_kernel = NUTS(model2)
mcmc = MCMC(
    nuts_kernel,
    num_warmup=1000,
    num_samples=num_samples,
    num_chains=num_chains,
)
# After running, mcmc.print_summary() reports r_hat and the number of
# divergences, which is where NaN-gradient problems in the model show up.
mcmc.run(rng_key, x=x, y=y)