Trouble with jax.lax.while_loop() in gradient

This is probably more of a JAX problem, but I’m only encountering the issue in the context of numpyro sampling, and this group has been so helpful that I thought I’d try here too.

I’m fumbling my way through implementing custom JVPs for the regularized incomplete beta function by porting over some code from the Stan project (this and this). I’m trying to model a survey where people were asked to report probabilities. The underlying cognitive model involves a Beta distribution, but the responses are often rounded, so the effort to implement the Beta CDF (and its gradients) relates to handling that rounding.

The functions use a while loop that I am trying to implement with jax.lax.while_loop(). My implementations seem to work in some contexts but not others. When I try to run this inside a numpyro MCMC sampling context, I get the following error:

body_fun output and input must have identical types, got
('ShapedArray(float32[11])', 'DIFFERENT ShapedArray(float32[11]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float32[11]) vs. ShapedArray(float32[])', 'ShapedArray(float32[])', 'ShapedArray(int32[], weak_type=True)', 'ShapedArray(float32[])', 'ShapedArray(float32[])', 'ShapedArray(float32[])', 'ShapedArray(float32[11])').

I gather I am somehow changing the types inside my body_fun (_betainc_dda_while below), but I can’t figure out where. I’m hoping for some help debugging, or even just help reading this error message; I can’t see where the issue is arising.
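For reference, here is a stripped-down sketch (not the model code) of how jax.lax.while_loop produces this error: if any element of the carry changes shape inside body_fun, the loop refuses to trace. In this toy example acc enters as a scalar and leaves as a length-11 vector, analogous to the second and third entries in the error above.

import jax
import jax.numpy as jnp

def cond(carry):
    return jnp.all(carry[0] < 10.)

def body(carry):
    acc, x = carry
    acc = acc + x  # float32[] + float32[11] -> float32[11]: the carry type changes
    return (acc, x)

x = jnp.linspace(0., 1., 11)
# jax.lax.while_loop(cond, body, (0., x))  # raises: body_fun output and input must have identical types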

Here’s my code:

import jax
import jax.numpy as jnp
from jax import custom_jvp, grad
from jax.scipy.special import gammaln
from jax.numpy import log1p, log, exp

@custom_jvp
def _betainc(a,b,x):
    return jax.scipy.special.betainc(a,b,x) # same primal; custom JVPs are attached below



def _betainc_dda_while(args):
    # one term of the series for the partial derivative of betainc w.r.t. a (ported from Stan)
    summand, sum_numer, sum_denom, a_plus_b, k, a_plus_1, digamma_ab, digamma_a, x = args
    
    sum_numer = sum_numer + (digamma_ab - digamma_a) * summand
    sum_denom += summand
    summand = summand*(1 + (a_plus_b) / k) * (1 + k) / (1 + a_plus_1 / k)
    digamma_ab += 1./(a_plus_b + k)
    digamma_a += 1./(a_plus_1 + k)
    k += 1
    summand = summand * (x / k)

    args = (summand, sum_numer, sum_denom, a_plus_b, k, a_plus_1, digamma_ab, digamma_a, x) 
    
    return args



## define partial derivative wrt a
def _betainc_dda(a_dot, primal_out, a, b, x):
    a = jnp.asarray(a)
    b = jnp.asarray(b)
    x = jnp.asarray(x)
    
    digamma_a = jax.scipy.special.digamma(a)
    digamma_ab = jax.scipy.special.digamma(a+b)
    
    threshold = 1e-10

    a_plus_b = a + b
    a_plus_1 = a + 1.
    
    digamma_a = digamma_a + 1./a
    
    prefactor = jnp.power(a_plus_1 / a_plus_b, 3)
    sum_numer = (digamma_ab - digamma_a) * prefactor
    sum_denom = prefactor
    summand = prefactor * x * a_plus_b / a_plus_1
    
    k = 1
    digamma_ab = digamma_ab + 1./a_plus_b
    digamma_a = digamma_a + 1./a_plus_1
    
    ### ----- 6/9/22, 4:03 PM something in the while loop changing the types?
    out = jax.lax.while_loop(
        lambda args: jnp.any(jnp.abs(args[0]) > 1e-10),
        _betainc_dda_while,
        (summand, sum_numer, sum_denom, a_plus_b, k, a_plus_1, digamma_ab, digamma_a, x)
    )
    
    summand, sum_numer, sum_denom, a_plus_b, k, a_plus_1, digamma_ab, digamma_a, x = out
    
    
    return _betainc(a, b, x) * (log(x) + sum_numer / sum_denom)*a_dot


def _betainc_ddb_while(args):
    # one term of the series for the partial derivative of betainc w.r.t. b (ported from Stan)
    summand, sum_numer, sum_denom, a_plus_b, k, a_plus_1, digamma_ab, x = args
    
    sum_numer += digamma_ab * summand
    sum_denom += summand

    summand = summand*(1 + (a_plus_b) / k) * (1 + k) / (1 + a_plus_1 / k)
    digamma_ab += 1./(a_plus_b + k)
    k +=1
    summand = summand * x / k

    args = (summand, sum_numer, sum_denom, a_plus_b, k, a_plus_1, digamma_ab, x)
    
    return args
        

## define partial derivative wrt b
def _betainc_ddb(b_dot, primal_out, a, b, x):
    
    digamma_b = jax.scipy.special.digamma(b)
    digamma_ab = jax.scipy.special.digamma(a+b)
    
    threshold = 1e-10
    
    a_plus_b = a + b
    a_plus_1 = a + 1.
    
    prefactor = jnp.power(a_plus_1 / a_plus_b, 3)
    
    sum_numer = digamma_ab * prefactor
    sum_denom = prefactor
    summand = prefactor * x * a_plus_b / a_plus_1
    
    k = 1
    digamma_ab = digamma_ab + 1./a_plus_b
    
    out = jax.lax.while_loop(
        lambda args: jnp.any(jnp.abs(args[0]) > 1e-10),
        _betainc_ddb_while,
        (summand, sum_numer, sum_denom, a_plus_b, k, a_plus_1, digamma_ab, x)
    )
    
    summand, sum_numer, sum_denom, a_plus_b, k, a_plus_1, digamma_ab, x = out
        
    
    return _betainc(a, b, x) * (log(1 - x) - digamma_b + sum_numer / sum_denom)*b_dot




def betainc_gradx(g, primal_out, a, b, x):
    lbeta = gammaln(a) + gammaln(b) - gammaln(a + b)
    partial_x = exp((b - 1) * log1p(-x) +
                  (a - 1) * log(x) - lbeta)
    return partial_x * g

_betainc.defjvps(_betainc_dda, _betainc_ddb, betainc_gradx)

## these work
print(grad(_betainc, 0)(2., 1., .6)) 
print(grad(_betainc, 1)(2., 1., .6))
print(grad(_betainc, 2)(2., 1., .6))

Here’s the numpyro data simulation and model I’m using for testing:

from jax import random
from jax.random import PRNGKey
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

X_raw = dist.Beta(.75*10, (1-.75)*10).sample(PRNGKey(10), (100,))
x_round = jnp.round(X_raw*10)

def f(mu, k):
    a = mu*k
    b = (1.-mu)*k
    
    responses = jnp.linspace(0,10, num=11)
    lower = jnp.clip((responses/10.) - .05, 0., 1.)
    upper = jnp.clip((responses/10.) + .05, 0., 1.)

    prob_resps = _betainc(a, b, upper) - _betainc(a, b, lower)
    
    return prob_resps


def mymodel_round(x=None):
    mu = numpyro.sample("mu", dist.Beta(1,1)) # mean of the response Beta
    k = numpyro.sample("k", dist.HalfCauchy(10)) # concentration of the response Beta

    resp_probs = f(mu,k)
    
    with numpyro.plate("data", x.shape[0]):

        xhat = numpyro.sample("xhat", dist.Categorical(probs=resp_probs), obs=x)

And this produces the error:

kernel = NUTS(mymodel_round, target_accept_prob=.80)

mcmc = MCMC(kernel, 
               num_warmup=1_000, 
               num_samples=1_000, 
               num_chains=1)

mcmc.run(random.PRNGKey(0), x_round)

Appreciate any help!

No idea, but I know that in some cases you need to be careful to broadcast correctly (unless this has since changed), e.g. see here.

Yes, it looks like that was at least part of the issue; I was able to get some help from the jax folks and resolve this error. Unfortunately I now get some new errors, so I haven’t yet been able to get this working.
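The exact change isn’t shown above, but for anyone following along, here is a minimal sketch of the kind of fix the broadcasting hint points at, assuming (as the error message suggests) that sum_numer and sum_denom were the offending scalars in _betainc_dda: make every element of the while_loop carry enter the loop with the same shape it will have coming out of body_fun.

# hypothetical sketch, inside _betainc_dda, just before the loop
sum_numer = jnp.broadcast_to(sum_numer, summand.shape)  # float32[] -> float32[11]
sum_denom = jnp.broadcast_to(sum_denom, summand.shape)

out = jax.lax.while_loop(
    lambda args: jnp.any(jnp.abs(args[0]) > 1e-10),
    _betainc_dda_while,
    (summand, sum_numer, sum_denom, a_plus_b, k, a_plus_1, digamma_ab, digamma_a, x)
)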

In addition to the error I mentioned in the linked thread, when I try to run the model it can’t initialize parameters. Even passing in initial parameters that “work” for the functions outside of inference, I get the initialization error.

Could you try to use tfp.math.betainc?

import tensorflow_probability.substrates.jax as tfp

Thanks @fehiepsi. I had done some reading on that function that left me thinking it also didn’t have gradients implemented for the a and b arguments, but it seems I was wrong about that? I think there’s some progress here, but now I get an error saying it cannot find valid initial parameters. I get this error even when I manually specify valid initial parameters.
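A quick sanity check (mirroring the grad calls earlier in the thread, and not part of the model) that tfp.math.betainc is differentiable w.r.t. a and b under JAX:

import jax
import tensorflow_probability.substrates.jax as tfp

print(jax.grad(tfp.math.betainc, argnums=0)(2., 1., .6))  # d/da
print(jax.grad(tfp.math.betainc, argnums=1)(2., 1., .6))  # d/db
print(jax.grad(tfp.math.betainc, argnums=2)(2., 1., .6))  # d/dx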

As things have gotten very long here, here’s a new minimal repro:

import jax
import jax.numpy as jnp
from jax import random
from jax.random import PRNGKey
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, init_to_value

## simulate some data
X_raw = dist.Beta(.75*10, (1-.75)*10).sample(PRNGKey(10), (100,))
x_round = jnp.round(X_raw*10)

## define the model
import tensorflow_probability.substrates.jax as tfp

def f(mu, k):
    a = mu*k
    b = (1.-mu)*k
    
    responses = jnp.linspace(0,10, num=11)
    lower = jnp.clip((responses/10.) - .05, 0., 1.)
    upper = jnp.clip((responses/10.) + .05, 0., 1.)
    
    prob_resps = tfp.math.betainc(a, b, upper) - tfp.math.betainc(a, b, lower)
    
    return prob_resps


def mymodel_round(x=None):
    mu = numpyro.sample("mu", dist.Beta(1,1)) # mean of the response Beta
    k = numpyro.sample("k", dist.HalfCauchy(10)) # concentration of the response Beta

    resp_probs = f(mu,k)
    resp_probs = resp_probs/jnp.sum(resp_probs)
    
    with numpyro.plate("data", x.shape[0]):

        xhat = numpyro.sample("xhat", dist.Categorical(probs=resp_probs), obs=x)


## sample

kernel = NUTS(mymodel_round, init_strategy = init_to_value(values={"mu":.33, "k":5}))

mcmc = MCMC(kernel, 
               num_warmup=1_000, 
               num_samples=1_000, 
               num_chains=1)

mcmc.run(random.PRNGKey(0), x_round)

The result:

RuntimeError: Cannot find valid initial parameters. Please check your model again.

I guess you might need to avoid extreme values at 0. or 1. when computing betainc.

This works! You are my hero @fehiepsi!

For anyone curious, changing jnp.clip(..., 0., 1.) to jnp.clip(..., .001, .999) seems to have solved the problem.
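In other words, the two clip lines in f(mu, k) from the repro become:

# keep the betainc arguments strictly inside (0, 1) so the CDF and its gradient stay finite
lower = jnp.clip((responses/10.) - .05, .001, .999)
upper = jnp.clip((responses/10.) + .05, .001, .999)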