Trouble with jax.lax.while_loop() in gradient

This is probably more of a Jax problem but I’m only encountering the issue in the context of numpyro sampling and this group has been so helpful I thought I’d try here too.

I’m fumbling my way through implementing some custom jvp for the regularized incomplete beta function by attempting to port over some code from the Stan project (this and this). I’m trying to model a survey where people were asked to report probabilities. The underlying cognitive model involves a Beta distribution, but the responses are often rounded. So the efforts to implement the Beta CDF relate to the rounding.

The functions use a while loop that I am trying to implement with jax.lax.while_loop(). My implementations seem to work in some contexts but not others. When I try to run this inside a numpyro MCMC sampling context, I get the following error:

body_fun output and input must have identical types, got
('ShapedArray(float32[11])', 'DIFFERENT ShapedArray(float32[11]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float32[11]) vs. ShapedArray(float32[])', 'ShapedArray(float32[])', 'ShapedArray(int32[], weak_type=True)', 'ShapedArray(float32[])', 'ShapedArray(float32[])', 'ShapedArray(float32[])', 'ShapedArray(float32[11])').

I gather I am somehow transforming the types inside my body_fun (_betainc_dda_while below) but I can’t seem to figure out where. Hoping for some help debugging or even reading this error message, I can’t see where the issue is arising.

here’s my code:

import jax
import jax.nump as jnp
from jax import custom_jvp
from jax.scipy.special import gammaln
from jax.numpy import log1p, log, exp

@custom_jvp
def _betainc(a,b,x):
    return jax.scipy.special.betainc(a,b,x) # assign my own



def _betainc_dda_while(args):
    summand, sum_numer, sum_denom, a_plus_b, k, a_plus_1, digamma_ab, digamma_a, x = args
    
    sum_numer = sum_numer + (digamma_ab - digamma_a) * summand
    sum_denom += summand
    summand = summand*(1 + (a_plus_b) / k) * (1 + k) / (1 + a_plus_1 / k)
    digamma_ab += 1./(a_plus_b + k)
    digamma_a += 1./(a_plus_1 + k)
    k += 1
    summand = summand * (x / k)

    args = (summand, sum_numer, sum_denom, a_plus_b, k, a_plus_1, digamma_ab, digamma_a, x) 
    
    return args



## define partial derivative wrt a
def _betainc_dda(a_dot, primal_out, a, b, x):
    a = jnp.asarray(a)
    b = jnp.asarray(a)
    x = jnp.asarray(x)
    
    digamma_a = jax.scipy.special.digamma(a)
    digamma_ab = jax.scipy.special.digamma(a+b)
    
    threshold = 1e-10

    a_plus_b = a + b
    a_plus_1 = a + 1.
    
    digamma_a = digamma_a + 1./a
    
    prefactor = jnp.power(a_plus_1 / a_plus_b, 3)
    sum_numer = (digamma_ab - digamma_a) * prefactor
    sum_denom = prefactor
    summand = prefactor * x * a_plus_b / a_plus_1
    
    k = 1
    digamma_ab = digamma_ab + 1./a_plus_b
    digamma_a = digamma_a + 1./a_plus_1
    
    ### ----- 6/9/22, 4:03 PM something in the while loop changing the types?
    out = jax.lax.while_loop(
        lambda args: jnp.any(jnp.abs(args[0]) > 1e-10),
        _betainc_dda_while,
        (summand, sum_numer, sum_denom, a_plus_b, k, a_plus_1, digamma_ab, digamma_a, x)
    )
    
    summand, sum_numer, sum_denom, a_plus_b, k, a_plus_1, digamma_ab, digamma_a, x = out
    
    
    return _betainc(a, b, x) * (log(x) + sum_numer / sum_denom)*a_dot


def _betainc_ddb_while(args):
    summand, sum_numer, sum_denom, a_plus_b, k, a_plus_1, digamma_ab, x = args
    
    sum_numer += digamma_ab * summand
    sum_denom += summand

    summand = summand*(1 + (a_plus_b) / k) * (1 + k) / (1 + a_plus_1 / k)
    digamma_ab += 1./(a_plus_b + k)
    k +=1
    summand = summand * x / k

    args = (summand, sum_numer, sum_denom, a_plus_b, k, a_plus_1, digamma_ab, x)
    
    return args
        

## define partial derivative wrt b
def _betainc_ddb(b_dot, primal_out, a, b, x):
    
    digamma_b = jax.scipy.special.digamma(b)
    digamma_ab = jax.scipy.special.digamma(a+b)
    
    threshold = 1e-10
    
    a_plus_b = a + b
    a_plus_1 = a + 1.
    
    prefactor = jnp.power(a_plus_1 / a_plus_b, 3)
    
    sum_numer = digamma_ab * prefactor
    sum_denom = prefactor
    summand = prefactor * x * a_plus_b / a_plus_1
    
    k = 1
    digamma_ab = digamma_ab + 1./a_plus_b
    
    out = jax.lax.while_loop(
        lambda args: jnp.any(jnp.abs(args[0]) > 1e-10),
        _betainc_ddb_while,
        (summand, sum_numer, sum_denom, a_plus_b, k, a_plus_1, digamma_ab, x)
    )
    
    summand, sum_numer, sum_denom, a_plus_b, k, a_plus_1, digamma_ab, x = out
        
    
    return _betainc(a, b, x) * (log(1 - x) - digamma_b + sum_numer / sum_denom)*b_dot




def betainc_gradx(g, primal_out, a, b, x):
    lbeta = gammaln(a) + gammaln(b) - gammaln(a + b)
    partial_x = exp((b - 1) * log1p(-x) +
                  (a - 1) * log(x) - lbeta)
    return partial_x * g

_betainc.defjvps(_betainc_dda, _betainc_ddb, betainc_gradx)

## these work
print(grad(_betainc, 0)(2., 1., .6)) 
print(grad(_betainc, 1)(2., 1., .6))
print(grad(_betainc, 2)(2., 1., .6))

Here’s the numpyro data simulation and model I’m using for testing:

from jax.random import PRNGKey

X_raw = dist.Beta(.75*10, (1-.75)*10).sample(PRNGKey(10), (100,))
x_round = round(X_raw*10)

def f(mu, k):
    a = mu*k
    b = (1.-mu)*k
    
    responses = jnp.linspace(0,10, num=11)
    lower = jnp.clip((responses/10.) - .05, 0., 1.)
    upper = jnp.clip((responses/10.) + .05, 0., 1.)

    prob_resps = _betainc(a, b, upper) - _betainc(a, b, lower)
    
    return(prob_resps)


def mymodel_round(x=None):
    mu = numpyro.sample("mu", dist.Beta(1,1)) # noise parameter
    k = numpyro.sample("k", dist.HalfCauchy(10)) # noise parameter

    resp_probs = f(mu,k)
    
    with numpyro.plate("data", x.shape[0]):

        xhat = numpyro.sample("xhat", dist.Categorical(probs=resp_probs), obs=x)

And this produces the error

kernel = NUTS(mymodel_round, target_accept_prob=.80)

mcmc = MCMC(kernel, 
               num_warmup=1_000, 
               num_samples=1_000, 
               num_chains=1)

mcmc.run(random.PRNGKey(0), x_round)

Appreciate any help!

no idea but i know in some cases you need to be careful to broadcast correctly (unless this has since been changed), e.g. see here