This is probably more of a Jax problem but I’m only encountering the issue in the context of numpyro sampling and this group has been so helpful I thought I’d try here too.
I’m fumbling my way through implementing some custom jvp for the regularized incomplete beta function by attempting to port over some code from the Stan project (this and this). I’m trying to model a survey where people were asked to report probabilities. The underlying cognitive model involves a Beta distribution, but the responses are often rounded. So the efforts to implement the Beta CDF relate to the rounding.
The functions use a while loop that I am trying to implement with jax.lax.while_loop()
. My implementations seem to work in some contexts but not others. When I try to run this inside a numpyro MCMC sampling context, I get the following error:
body_fun output and input must have identical types, got
('ShapedArray(float32[11])', 'DIFFERENT ShapedArray(float32[11]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float32[11]) vs. ShapedArray(float32[])', 'ShapedArray(float32[])', 'ShapedArray(int32[], weak_type=True)', 'ShapedArray(float32[])', 'ShapedArray(float32[])', 'ShapedArray(float32[])', 'ShapedArray(float32[11])').
I gather I am somehow transforming the types inside my body_fun
(_betainc_dda_while
below) but I can’t seem to figure out where. Hoping for some help debugging or even reading this error message, I can’t see where the issue is arising.
here’s my code:
import jax
import jax.nump as jnp
from jax import custom_jvp
from jax.scipy.special import gammaln
from jax.numpy import log1p, log, exp
@custom_jvp
def _betainc(a,b,x):
return jax.scipy.special.betainc(a,b,x) # assign my own
def _betainc_dda_while(args):
summand, sum_numer, sum_denom, a_plus_b, k, a_plus_1, digamma_ab, digamma_a, x = args
sum_numer = sum_numer + (digamma_ab - digamma_a) * summand
sum_denom += summand
summand = summand*(1 + (a_plus_b) / k) * (1 + k) / (1 + a_plus_1 / k)
digamma_ab += 1./(a_plus_b + k)
digamma_a += 1./(a_plus_1 + k)
k += 1
summand = summand * (x / k)
args = (summand, sum_numer, sum_denom, a_plus_b, k, a_plus_1, digamma_ab, digamma_a, x)
return args
## define partial derivative wrt a
def _betainc_dda(a_dot, primal_out, a, b, x):
a = jnp.asarray(a)
b = jnp.asarray(a)
x = jnp.asarray(x)
digamma_a = jax.scipy.special.digamma(a)
digamma_ab = jax.scipy.special.digamma(a+b)
threshold = 1e-10
a_plus_b = a + b
a_plus_1 = a + 1.
digamma_a = digamma_a + 1./a
prefactor = jnp.power(a_plus_1 / a_plus_b, 3)
sum_numer = (digamma_ab - digamma_a) * prefactor
sum_denom = prefactor
summand = prefactor * x * a_plus_b / a_plus_1
k = 1
digamma_ab = digamma_ab + 1./a_plus_b
digamma_a = digamma_a + 1./a_plus_1
### ----- 6/9/22, 4:03 PM something in the while loop changing the types?
out = jax.lax.while_loop(
lambda args: jnp.any(jnp.abs(args[0]) > 1e-10),
_betainc_dda_while,
(summand, sum_numer, sum_denom, a_plus_b, k, a_plus_1, digamma_ab, digamma_a, x)
)
summand, sum_numer, sum_denom, a_plus_b, k, a_plus_1, digamma_ab, digamma_a, x = out
return _betainc(a, b, x) * (log(x) + sum_numer / sum_denom)*a_dot
def _betainc_ddb_while(args):
summand, sum_numer, sum_denom, a_plus_b, k, a_plus_1, digamma_ab, x = args
sum_numer += digamma_ab * summand
sum_denom += summand
summand = summand*(1 + (a_plus_b) / k) * (1 + k) / (1 + a_plus_1 / k)
digamma_ab += 1./(a_plus_b + k)
k +=1
summand = summand * x / k
args = (summand, sum_numer, sum_denom, a_plus_b, k, a_plus_1, digamma_ab, x)
return args
## define partial derivative wrt b
def _betainc_ddb(b_dot, primal_out, a, b, x):
digamma_b = jax.scipy.special.digamma(b)
digamma_ab = jax.scipy.special.digamma(a+b)
threshold = 1e-10
a_plus_b = a + b
a_plus_1 = a + 1.
prefactor = jnp.power(a_plus_1 / a_plus_b, 3)
sum_numer = digamma_ab * prefactor
sum_denom = prefactor
summand = prefactor * x * a_plus_b / a_plus_1
k = 1
digamma_ab = digamma_ab + 1./a_plus_b
out = jax.lax.while_loop(
lambda args: jnp.any(jnp.abs(args[0]) > 1e-10),
_betainc_ddb_while,
(summand, sum_numer, sum_denom, a_plus_b, k, a_plus_1, digamma_ab, x)
)
summand, sum_numer, sum_denom, a_plus_b, k, a_plus_1, digamma_ab, x = out
return _betainc(a, b, x) * (log(1 - x) - digamma_b + sum_numer / sum_denom)*b_dot
def betainc_gradx(g, primal_out, a, b, x):
lbeta = gammaln(a) + gammaln(b) - gammaln(a + b)
partial_x = exp((b - 1) * log1p(-x) +
(a - 1) * log(x) - lbeta)
return partial_x * g
_betainc.defjvps(_betainc_dda, _betainc_ddb, betainc_gradx)
## these work
print(grad(_betainc, 0)(2., 1., .6))
print(grad(_betainc, 1)(2., 1., .6))
print(grad(_betainc, 2)(2., 1., .6))
Here’s the numpyro data simulation and model I’m using for testing:
from jax.random import PRNGKey
X_raw = dist.Beta(.75*10, (1-.75)*10).sample(PRNGKey(10), (100,))
x_round = round(X_raw*10)
def f(mu, k):
a = mu*k
b = (1.-mu)*k
responses = jnp.linspace(0,10, num=11)
lower = jnp.clip((responses/10.) - .05, 0., 1.)
upper = jnp.clip((responses/10.) + .05, 0., 1.)
prob_resps = _betainc(a, b, upper) - _betainc(a, b, lower)
return(prob_resps)
def mymodel_round(x=None):
mu = numpyro.sample("mu", dist.Beta(1,1)) # noise parameter
k = numpyro.sample("k", dist.HalfCauchy(10)) # noise parameter
resp_probs = f(mu,k)
with numpyro.plate("data", x.shape[0]):
xhat = numpyro.sample("xhat", dist.Categorical(probs=resp_probs), obs=x)
And this produces the error
kernel = NUTS(mymodel_round, target_accept_prob=.80)
mcmc = MCMC(kernel,
num_warmup=1_000,
num_samples=1_000,
num_chains=1)
mcmc.run(random.PRNGKey(0), x_round)
Appreciate any help!