Model with function, debug

I have a model that uses the Hill function from biology, which has two parameters (ec50 and slope).
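For reference, jhill below computes the standard Hill function, written here with n for slope and EC50 for ec50. Because n > 0 under the Uniform(.5, 7) prior, h(x) tends to 0 as x tends to 0, which agrees with the value the code assigns at x == 0:

```latex
h(x) = \frac{1}{1 + \left(x / \mathrm{EC}_{50}\right)^{-n}}
     = \frac{x^{n}}{\mathrm{EC}_{50}^{\,n} + x^{n}}, \qquad x > 0,
\qquad \lim_{x \to 0^{+}} h(x) = 0 \quad (n > 0).
```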

The code below runs without errors, but the r_hat values are astronomical, which suggests something is wrong.

x has a lower bound of 0, and I think I have fixed the division-by-zero issues.

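# Imports assumed by the snippet below
import jax.numpy as jnp
from jax import jit, random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

@jit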
def jhill(x, ec50, slope):
  return jnp.where(x == 0, 0, 1. / (1 + jnp.power(x / ec50, -slope)))

def model2(x=None, y=None):
  a = numpyro.sample('a', dist.Normal(0, 10))
  b = numpyro.sample('b', dist.Uniform(0, 10))
  slope = numpyro.sample('slope', dist.Uniform(.5, 7))
  ec50 = numpyro.sample('ec50', dist.Gamma(1, 1))
  sigma = numpyro.sample('sigma', dist.Gamma(1, 1))
  numpyro.sample('obs', dist.Normal(a + b * jhill(x, ec50, slope), sigma), obs=y)

nuts_kernel = NUTS(model2)
mcmc = MCMC(nuts_kernel, num_warmup=1000, num_samples=num_samples, num_chains=num_chains)
rng_key = random.PRNGKey(0)
mcmc.run(rng_key, x=x, y=y)
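If x contains exact zeros, one thing worth checking is that jnp.where only masks the forward value: JAX still evaluates and differentiates both branches, so at x == 0 the unused branch computes 0 ** (-slope) = inf, and the chain rule turns that into NaN gradients, which NUTS relies on. A common workaround (the "double where" pattern described in the JAX FAQ) is to substitute a safe dummy input before the risky operation. The sketch below compares the two; jhill_naive, jhill_safe, and the parameter values are illustrative, and whether this alone explains the r_hat values is an assumption.

```python
import jax
import jax.numpy as jnp

def jhill_naive(x, ec50, slope):
    # As in the question: the x == 0 entries are masked in the output,
    # but the unused branch still computes 0 ** (-slope) = inf.
    return jnp.where(x == 0, 0.0, 1.0 / (1.0 + jnp.power(x / ec50, -slope)))

def jhill_safe(x, ec50, slope):
    # "Double where": replace x == 0 with a harmless dummy value (1.0) before
    # the power, then mask the result. Neither branch ever produces inf.
    is_zero = x == 0
    x_safe = jnp.where(is_zero, 1.0, x)
    h = 1.0 / (1.0 + jnp.power(x_safe / ec50, -slope))
    return jnp.where(is_zero, 0.0, h)

x = jnp.array([0.0, 0.5, 2.0])
params = jnp.array([1.0, 2.0])  # (ec50, slope), illustrative values

def summed(f):
    # Scalar function of (ec50, slope), like a log-density term NUTS differentiates.
    return lambda p: jnp.sum(f(x, p[0], p[1]))

g_naive = jax.grad(summed(jhill_naive))(params)
g_safe = jax.grad(summed(jhill_safe))(params)
print(g_naive)  # contains nan
print(g_safe)   # finite
```

Both versions return the same forward values; only the gradients differ.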

You might try clipping ec50 away from zero, since very small values of ec50 could lead to overflow:

  ec50 = numpyro.sample('ec50', dist.Gamma(1, 1))
+ ec50 = jnp.clip(ec50, a_min=1e-3)  # play around with threshold

You might also try a gradient-free algorithm, since your model looks like it may have fairly extreme curvature.

For example, NumPyro's SA (Sample Adaptive MCMC) kernel: kernel = SA(model2)

Another option is to implement the Hill function in log space:

jax.scipy.special.expit(slope * jnp.log(jnp.clip(x, a_min=1e-7) / ec50))
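The log-space form relies on the identity expit(n * log(x / ec50)) = 1 / (1 + (x / ec50) ** -n) for x > 0, where expit is the logistic function. A quick numerical check of that identity, plus a check that gradients stay finite when the data include x == 0, might look like this; hill_power, hill_logspace, and the eps default of 1e-7 (matching the clip above) are illustrative choices. At x == 0 the clipped version returns roughly (eps / ec50) ** slope rather than exactly 0, which may not be negligible for small slope values.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import expit

def hill_power(x, ec50, slope):
    # Original power form; only safe for x > 0.
    return 1.0 / (1.0 + jnp.power(x / ec50, -slope))

def hill_logspace(x, ec50, slope, eps=1e-7):
    # expit(s * log(x / ec50)) == 1 / (1 + (x / ec50) ** -s) for x > 0.
    # Flooring x at eps keeps the log finite, so x == 0 gives a tiny
    # positive value instead of log(0) = -inf.
    return expit(slope * jnp.log(jnp.maximum(x, eps) / ec50))

# The two forms agree for strictly positive x.
x_pos = jnp.array([0.1, 0.5, 1.0, 2.0, 10.0])
max_diff = jnp.max(jnp.abs(hill_logspace(x_pos, 1.5, 2.0) - hill_power(x_pos, 1.5, 2.0)))
print(max_diff)

# Gradient w.r.t. (ec50, slope), summed over data that includes x == 0.
x_all = jnp.array([0.0, 0.5, 2.0])
grad = jax.grad(lambda p: jnp.sum(hill_logspace(x_all, p[0], p[1])))(jnp.array([1.5, 2.0]))
print(grad)  # finite
```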

You could use JaxNS for this. It’s not gradient-based and can give arbitrarily precise results. JaxNS is currently being integrated into NumPyro (@fehiepsi), but you can also use JaxNS’s own probabilistic programming infrastructure directly if you prefer. I encourage you to post your problem on the JaxNS discussion page, and I’ll lend support.