Making a for-loop more efficient

Hi Everyone,

While sampling with NUTS, I’m solving a system of differential algebraic equations using JAX’s odeint(). To do this, I first have to solve the algebraic equations, and only then solve the differential equations using jnp.matmul(). The algebraic equations depend on one another: I solve for x1 first, then I need x1 to solve for x2, and so on.

Because these equations depend on one another, I have only managed to implement them with a for-loop. I tried vmap(), but it uses the previous value of x1 rather than the newly computed one when calculating x2 (in my code that previous value is 0, because I initialized the array with zeros).
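To make the vmap problem concrete, here is a toy sketch with two hypothetical auxiliaries where x2 depends on x1. Mapping over the equations evaluates every row against the same starting vector, so x2 still sees x1 = 0; a sequential update sees the freshly computed x1. The matrix `P`, the intercepts `b` and the helper `sequential` are made up for illustration.

```python
import jax
import jax.numpy as jnp

# Hypothetical system: x1 = 1.0, x2 = 0.5 + 2.0 * x1
b = jnp.array([1.0, 0.5])            # intercepts
P = jnp.array([[0.0, 0.0],           # row i holds the coefficients of equation i
               [2.0, 0.0]])
x0 = jnp.zeros(2)                    # initial values, like the zeros array

# "Parallel" evaluation: every row is evaluated against the same x0,
# so x2 is computed with x1 still equal to 0.
parallel = jax.vmap(lambda row, bi: bi + row @ x0)(P, b)

# Sequential evaluation: equation i sees the values already updated for j < i.
def sequential(P, b, x):
    for i in range(b.shape[0]):
        x = x.at[i].set(b[i] + P[i] @ x)
    return x

seq = sequential(P, b, x0)
```

Here `parallel` gives x2 = 0.5 while `seq` gives x2 = 2.5, which is the behaviour described above.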

For illustration, here’s the code I’m using. Unfortunately, adding these algebraic equations has made the code four times slower than solving the DEs alone. “s” is a namespace holding settings.

Any suggestions for how to speed up this code?

import jax.numpy as jnp

def dz_dt(z, t, A, constants, intercept_stocks):
    """Solve the ordinary differential algebraic equations."""
    ## First solve the auxiliaries, in order
    auxiliaries = jnp.zeros([len(s.auxiliaries_sdm)])
    variables = jnp.append(jnp.append(auxiliaries, z), constants)

    for i, aux in enumerate(s.auxiliaries_sdm):
        # Intercept plus the weighted sum of this auxiliary's predictors;
        # predictors can include auxiliaries updated in earlier iterations
        preds = list(s.par_dict[aux].keys())[1:]
        curr_val = s.par_dict[aux]["Intercept"] + sum(
            [s.par_dict[aux][pred] * variables[s.sdm_vars.index(pred)] for pred in preds])
        # Equivalent to ops.index_update(variables, i, curr_val) in older JAX versions
        variables = variables.at[i].set(curr_val)

    # Finally solve for the stocks
    return jnp.matmul(A, variables) + intercept_stocks

I’m calling this function using:
vmap(lambda y0,constants: odeint(dz_dt, y0, s.ts[1:], K, constants, intercept))(data_bl_stocks, data_bl_constants)

JAX provides several control flow operators in jax.lax for loops whose iterations depend on each other, including jax.lax.fori_loop, which is equivalent to the following Python loop:

def fori_loop(lower, upper, body_fun, init_val):
  val = init_val
  for i in range(lower, upper):
    val = body_fun(i, val)
  return val

You might try rewriting your loop to use fori_loop or scan, which will definitely speed up compilation and may speed up the code itself.
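A minimal sketch of the auxiliary solve rewritten with jax.lax.fori_loop, assuming the predictor coefficients are first packed into a dense matrix `P` (one row per auxiliary, one column per entry of `variables`) instead of the dictionary lookups in `s.par_dict`. The names `solve_aux_fori`, `P` and the example numbers are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import lax

def solve_aux_fori(z, constants, P, intercept_aux):
    """Sequentially solve the auxiliaries with lax.fori_loop.

    Assumed (hypothetical) layout: variables = [auxiliaries, z, constants],
    and row i of P holds auxiliary i's coefficient on every entry of variables.
    """
    n_aux = P.shape[0]
    variables = jnp.concatenate([jnp.zeros(n_aux, dtype=z.dtype), z, constants])

    def body(i, variables):
        # Auxiliary i sees the values already written for auxiliaries < i.
        new_val = intercept_aux[i] + P[i] @ variables
        return variables.at[i].set(new_val)

    # n_aux is a Python int, so the trip count is known at trace time.
    return lax.fori_loop(0, n_aux, body, variables)

# Usage: aux0 = 0.1 + z[0], aux1 = 0.2 + 3 * aux0
z = jnp.array([1.0, 2.0])
constants = jnp.array([0.5])
P = jnp.array([[0.0, 0.0, 1.0, 0.0, 0.0],
               [3.0, 0.0, 0.0, 0.0, 0.0]])
intercept_aux = jnp.array([0.1, 0.2])
out = solve_aux_fori(z, constants, P, intercept_aux)
```

Because the loop body is traced once rather than unrolled for every auxiliary, compilation time no longer grows with the number of equations.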


Thanks a lot! fori_loop() seemed to speed up my code four-fold. Unfortunately, it does not seem to work with NUTS because of issues with autodiff.
I’m getting similar results with scan(), which should work with autodiff, but I’m still dealing with a TypeError. JAX can be difficult to debug sometimes. I’ll get back to you when I figure it out.

You may (?) need to use the forward_mode_differentiation=True flag in NUTS.

Thanks for your suggestion! You mean specifically for when I use fori_loop(), correct? Honestly, I’d be very happy to use fori_loop, as I understand how it works better than scan(), and I haven’t figured out the issue with scan() yet.

I seem to have two issues when using fori_loop():

(1) I would also like to use the numpyro.optim.Minimize() function to quickly find MAP values for some model comparisons. This function does not seem to have an option for forward-mode differentiation, and it gives me the error below. Is this not an option at all with fori_loop? (It’s not a huge problem, because Minimize() still runs in 15 minutes with my original for-loop implementation.)

ValueError: Reverse-mode differentiation does not work for lax.while_loop or lax.fori_loop. Try using lax.scan instead.

(2) When I run NUTS with the forward_mode_differentiation=True flag, I get the following error, which may be related to the fact that I’m using odeint():

TypeError: can't apply forward-mode autodiff (jvp) to a custom_vjp function.
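On (1): according to the jax.lax.fori_loop documentation for recent JAX versions, when the trip count is static (for example, the bounds are Python ints) the loop is implemented with scan and reverse-mode autodiff is supported; with traced bounds it falls back to while_loop, which produces the ValueError above. Whether this applies to a given setup depends on the installed JAX version, so treat this as a sketch. The function `repeated_scale` is hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import lax

def repeated_scale(x, a):
    # Multiply x by a five times. The bounds 0 and 5 are Python ints,
    # so the trip count is static and fori_loop can lower to scan.
    return lax.fori_loop(0, 5, lambda i, val: val * a, x)

# Reverse-mode gradient w.r.t. a: d/da (x * a**5) = 5 * x * a**4
g = jax.grad(repeated_scale, argnums=1)(jnp.array(2.0), jnp.array(1.5))
```

Here `g` should equal 5 * 2.0 * 1.5**4 = 50.625.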

Any workarounds, or would it be best to focus my efforts on scan()?
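The TypeError in (2) can be reproduced without odeint or NumPyro. odeint defines its gradient with jax.custom_vjp, and a custom_vjp function only supplies a reverse-mode rule. A minimal sketch with a hypothetical `my_sin`: reverse mode (jax.jacrev) works, while a forward-mode jvp raises the same TypeError.

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def my_sin(x):
    return jnp.sin(x)

def my_sin_fwd(x):
    # Save the input as the residual for the backward pass.
    return my_sin(x), x

def my_sin_bwd(x, g):
    return (jnp.cos(x) * g,)

my_sin.defvjp(my_sin_fwd, my_sin_bwd)

x = jnp.array([0.1, 0.2, 0.3])
J_rev = jax.jacrev(my_sin)(x)      # reverse mode uses my_sin_bwd: works

try:
    # Forward mode (what forward_mode_differentiation relies on)
    jax.jvp(my_sin, (x,), (jnp.ones_like(x),))
    forward_failed = False
except TypeError as err:
    forward_failed = True          # "can't apply forward-mode autodiff (jvp) ..."
```

So forward-mode differentiation through odeint is not expected to work, which points toward a scan-based loop plus reverse mode.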

I think your function will work with scan if it works with fori_loop. It is better to isolate the issue first by writing a function that uses your ODE implementation without any NumPyro code. Then test whether that function works under JAX transforms like jax.grad and jax.vmap.
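A sketch of that isolation step, using a hypothetical exponential-decay right-hand side in place of the real dz_dt: solve with odeint, then check jax.grad against the analytic derivative and jax.vmap over a batch of initial conditions, with no NumPyro involved.

```python
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint

def decay(y, t, k):
    # Hypothetical stand-in for dz_dt: dy/dt = -k * y
    return -k * y

ts = jnp.array([0.0, 1.0, 2.0])

def final_value(y0, k):
    # Solve the ODE and return y at the last time point
    ys = odeint(decay, y0, ts, k, rtol=1e-5, atol=1e-5)
    return ys[-1]

# Gradient through odeint; analytically y(T) = y0 * exp(-k T),
# so d/dk y(T) = -T * y0 * exp(-k T)
dk = jax.grad(final_value, argnums=1)(jnp.array(1.0), jnp.array(0.5))
expected = -2.0 * jnp.exp(-0.5 * 2.0)

# Batched solve over several initial conditions, k shared
batched = jax.vmap(final_value, in_axes=(0, None))(jnp.array([1.0, 2.0]), jnp.array(0.5))
```

If the real dz_dt passes the same kind of checks, the remaining problem is on the model side rather than in the ODE code.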

Thanks! Would you suggest opting for scan() then, rather than fori_loop?

And good advice! This is how I’ve been approaching it, but the error I’m getting is difficult to understand:

TypeError: primal and tangent arguments to jax.jvp do not match; dtypes must be equal, or in case of int/bool primal dtype the tangent dtype must be float0.Got primal dtype int32 and so expected tangent dtype [('float0', 'V')], but got tangent dtype int32 instead.

Somehow I get this with the scan() implementation but not with the fori_loop() one. It’s not clear exactly where in the code it happens. I thought it might be related to creating a jnp array of zeros, but doing that with numpy instead didn’t solve the issue.

Yes, it is better to use scan when you need gradients. It seems that some of your scan inputs/outputs don’t have the expected dtype. If you can’t share minimal reproducible code, you can reach out to the JAX devs to interpret the error message. Have you tried creating a JAX function with just scan to isolate the issue?
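A sketch of such a scan-only function, assuming the same hypothetical dense layout as before (`variables = [auxiliaries, z, constants]`, one row of `P` per auxiliary). The carry is always a float array, the integer `xs` are only used as indices (they are never differentiated), and every differentiable input is float. Names and numbers are illustrative.

```python
import jax
import jax.numpy as jnp
from jax import lax

def solve_aux_scan(z, constants, P, intercept_aux):
    """Sequential auxiliary solve with lax.scan (reverse-mode differentiable)."""
    n_aux = P.shape[0]
    variables = jnp.concatenate([jnp.zeros(n_aux, dtype=z.dtype), z, constants])

    def step(variables, i):
        new_val = intercept_aux[i] + P[i] @ variables
        # The carry keeps the same shape and float dtype on every step.
        return variables.at[i].set(new_val), None

    # One scan step per auxiliary index 0..n_aux-1
    variables, _ = lax.scan(step, variables, jnp.arange(n_aux))
    return variables

z = jnp.array([1.0, 2.0])
constants = jnp.array([0.5])
P = jnp.array([[0.0, 0.0, 1.0, 0.0, 0.0],
               [3.0, 0.0, 0.0, 0.0, 0.0]])
intercept_aux = jnp.array([0.1, 0.2])   # floats, not ints

out = solve_aux_scan(z, constants, P, intercept_aux)
# Reverse-mode gradient of the second auxiliary w.r.t. z
g = jax.grad(lambda z: solve_aux_scan(z, constants, P, intercept_aux)[1])(z)
```

Here aux1 = 0.2 + 3 * (0.1 + z[0]), so `g` should be [3, 0].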

Thanks a lot for the input! I’ve been working on the problem some more today but I did not solve the issue yet.

When I use jax.test_util’s check_grads function, I get the following error, which doesn’t refer back to my own code anywhere:
TypeError: primal and tangent arguments to jax.jvp do not match; dtypes must be equal, or in case of int/bool primal dtype the tangent dtype must be float0.Got primal dtype int32 and so expected tangent dtype [('float0', 'V')], but got tangent dtype int32 instead.

And when I use the code in a NumPyro model, either to sample from it or to find the MAP, I get the following error:
KeyError: dtype([('float0', 'V')])

Here’s the code for a minimal example:

import jax
import jax.numpy as jnp
from jax import vmap, ops
import numpy as np
from jax.experimental.ode import odeint

def dz_dt(z, t, A, constants, intercept_stocks, 
    aux_predictors_mat, intercept_aux): 
    variables, stored_vars = aux_scan(z, constants, aux_predictors_mat, intercept_aux)
    return jnp.matmul(A, variables) + intercept_stocks

def aux_scan(z, constants, aux_predictors_mat, intercept_aux):
    n = aux_predictors_mat.shape[0]
    xs = jnp.arange(0, n)    # one index per auxiliary (same error if I use numpy.arange())
    auxiliaries = jnp.array([0.0] * n)   # same error if I create this with numpy
    variables_init = jnp.append(jnp.append(auxiliaries, z), constants)

    def scan_body(variables, num_aux):
        # Update auxiliary num_aux using the current, partly updated variables
        variables = variables.at[num_aux].set(
            intercept_aux[num_aux] + jnp.matmul(variables, aux_predictors_mat[num_aux]))
        return variables, variables

    return jax.lax.scan(scan_body, init=variables_init, xs=xs)

ts = jnp.array(np.array([0.0, 3.0] + list(np.linspace(0, 60, int(60/6) + 1)[1:])))

z_init = jnp.array([[-1.0687675 , -0.39805314, -0.27391016, -0.93831366],
                       [-0.8567926 ,  0.344659  , -0.56589156, -0.0977652 ]])
constants = jnp.array([[ 1.2550328, -0.7836494],
                      [-0.7967919,  0.8055391]])

K = jnp.array([[ 1.,  0.,  0., -1.,  1.,  1.,  0.,  0.],
             [ 0.,  0.,  0.,  0., -1., -1., -1., -1.],
             [ 1.,  1.,  1.,  0.,  0.,  1.,  1.,  1.],
             [ 1.,  1.,  1.,  0.,  1.,  1.,  1.,  1.]])

params_aux_mat = jnp.array([[0.  , 0.  , 0.01, 0.  , 0.  , 0.01, 0.01, 0.01],
                             [0.  , 0.  , 0.01, 0.  , 0.  , 0.  , 0.01, 0.01]])

intercept_aux = jnp.array([0, 0])
intercept_stocks = jnp.array([0] * 4)

ode_sol = vmap(lambda y0, constants: odeint(dz_dt, y0, ts, K, constants,
                                            intercept_stocks, params_aux_mat,
                                            intercept_aux))(z_init, constants)

Calling the function with vmap() works, but check_grads gives the error mentioned above.

from jax.test_util import check_grads

check_grads(dz_dt, (z_init, ts, K, constants, intercept_stocks, params_aux_mat, intercept_aux), 1, eps=1e-2)

I’m not sure how check_grads works, but it seems that dz_dt does not work on the batched arrays z_init and constants. Did you mean to use vmap there? I changed

intercept_aux = jnp.array([0., 0.])
intercept_stocks = jnp.array([0.] * 4)

and it seems to work with

target = lambda y0, constants: odeint(
    dz_dt, y0, ts, K, constants, 
    intercept_stocks, params_aux_mat, intercept_aux).sum()
jax.grad(target)(z_init[0], constants[0]), jax.jacobian(jax.vmap(target))(z_init, constants)
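The change that mattered is the dtype: check_grads differentiates with respect to every argument it receives, including intercept_aux and intercept_stocks, and JAX refuses to differentiate with respect to int32 arrays. A minimal sketch with a hypothetical function `shift`:

```python
import jax
import jax.numpy as jnp

def shift(z, intercept):
    # Same kind of computation as the last line of dz_dt, reduced to a scalar
    return jnp.sum(z + intercept)

z = jnp.array([1.0, 2.0])
intercept_int = jnp.array([0, 0])        # int32: not differentiable
intercept_float = jnp.array([0.0, 0.0])  # float32: differentiable

try:
    jax.grad(shift, argnums=(0, 1))(z, intercept_int)
    int_failed = False
except TypeError:
    int_failed = True                    # grad requires float (inexact) inputs

grads = jax.grad(shift, argnums=(0, 1))(z, intercept_float)
```

Writing literals as `0.` instead of `0` is enough to get float arrays.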

Wow, it seems like indeed turning the zeros into floats rather than integers helped the code run! I’m not getting the correct coefficients back yet but that’s likely a different error in my code somewhere. Thanks!

Is it by any chance a familiar problem that the object returned by MCMC(NUTS(…),).run() sometimes takes forever to load when using, for instance, .get_samples() or print_summary()? Any hints as to what it means when this happens? Did my sampling fail?

I have this with the above model when I run NUTS, but not when I run SVI with Minimize(). In the latter case I just get incorrect coefficients.

By the way, you asked whether I meant to use vmap(). I’m a bit confused now, as you recommended it to me earlier in this thread.

Am I not using it correctly?

Never mind. In your reproducible code, you applied dz_dt to a batch of inputs. I think you need to vmap dz_dt there (in the check_grads line).
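A sketch of what vmapping inside the check_grads call looks like, using a hypothetical linear right-hand side `rhs` that, like dz_dt, expects one unbatched example. jax.vmap maps it over the leading batch axis of both arguments, so check_grads receives a function that accepts the batched arrays.

```python
import jax
import jax.numpy as jnp
from jax.test_util import check_grads

A = jnp.array([[1.0, 0.0, -1.0],
               [0.5, 2.0, 0.0]])

def rhs(z, constants):
    # Unbatched: z has shape (2,), constants has shape (1,)
    variables = jnp.concatenate([z, constants])
    return A @ variables

z_batch = jnp.array([[0.1, 0.2],
                     [0.3, 0.4]])
c_batch = jnp.array([[1.0],
                     [2.0]])

batched_rhs = jax.vmap(rhs)   # maps over axis 0 of both arguments
# Raises an AssertionError if the gradients disagree with finite differences
check_grads(batched_rhs, (z_batch, c_batch), order=1, eps=1e-2)
```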

Thanks! I’ll just leave it as is.
Just to not let the other question die out:

Also, although I do get parameter estimates from SVI / Minimize() (which are incorrect), when I then call guide.get_posterior(params) to check the uncertainty, I again get the error I mentioned earlier, i.e.:
TypeError: can't apply forward-mode autodiff (jvp) to a custom_vjp function.

NUTS should not cause an infinite loop. My guess is that your odeint does not stop. You might want to set mxstep as in the ODE example (I faced an infinite loop in the past and submitted that flag to JAX; hopefully it also resolves your case).
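A sketch of passing mxstep to odeint, with a hypothetical decay right-hand side. As far as I can tell from the jax.experimental.ode implementation, mxstep caps the number of internal solver steps taken between consecutive entries of `ts`, so a badly behaved right-hand side cannot keep the solver stepping forever.

```python
import jax.numpy as jnp
from jax.experimental.ode import odeint

def decay(y, t, k):
    # Hypothetical right-hand side: dy/dt = -k * y
    return -k * y

y0 = jnp.array([1.0, 2.0])
ts = jnp.linspace(0.0, 10.0, 5)
k = jnp.array(0.3)

# Same tolerances and step cap as in the model code in this thread
ys = odeint(decay, y0, ts, k, rtol=1e-6, atol=1e-5, mxstep=1000)
```

If the real model still hangs with mxstep set, the solver is probably not the part that fails to terminate.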

can’t apply forward-mode autodiff (jvp) to a custom_vjp function

I think Minimize uses forward mode, but odeint has a custom VJP, so this use case is not supported? It would be nice to point out where the error happens, with at least some traceback, so we can help better. (It is always tricky to identify the error from the error message alone, especially when the message does not come from our framework.) For example, knowing which lines in get_posterior cause the issue might give us more concrete ideas, but it is best to have some reproducible code (just remove all unnecessary stuff; for example, you can make dz/dt behave like an identity function or remove most of the latent samples…)

Hi fehiepsi! Thanks a lot for your input again. Unfortunately, the issue with NUTS wasn’t solved by adding the mxstep option. I did figure out that the problem is not related to scan(), as the infinite load persists when I reintroduce the for-loop. Any other suggestions?

Regarding the error message for Minimize(), I’ll share it below. If that isn’t enough, I’ll put together another minimal example; I’ll just have to add a model function to my previous example in this topic.

Here’s the “short” error that excludes Jax’s internal frames (the longer version was too long):

FilteredStackTrace                        Traceback (most recent call last)
/var/folders/58/rcr3xzjn6bscyr8llj4t161m0000gn/T/ipykernel_3097/ in <module>
      1 #Predictive(model, guide, num_samples=2000)
----> 2 guide.quantiles(params, 0.5)

~/Anaconda/anaconda3/lib/python3.7/site-packages/numpyro/infer/ in quantiles(self, params, quantiles)
   1110     def quantiles(self, params, quantiles):
-> 1111         transform = self.get_transform(params)
   1112         quantiles = jnp.array(quantiles)[..., None]

~/Anaconda/anaconda3/lib/python3.7/site-packages/numpyro/infer/ in get_transform(self, params)
   1082         loc = params["{}_loc".format(self.prefix)]
-> 1083         precision = hessian(loss_fn)(loc)
   1084         scale_tril = cholesky_of_inverse(precision)

~/Anaconda/anaconda3/lib/python3.7/site-packages/numpyro/infer/ in loss_fn(z)
   1079             params1["{}_loc".format(self.prefix)] = z
-> 1080             return self._loss_fn(params1)

~/Anaconda/anaconda3/lib/python3.7/site-packages/numpyro/infer/ in loss_fn(params)
   1062             return Trace_ELBO().loss(
-> 1063                 random.PRNGKey(0), params, self.model, self, *args, **kwargs
   1064             )

~/Anaconda/anaconda3/lib/python3.7/site-packages/numpyro/infer/ in loss(self, rng_key, param_map, model, guide, *args, **kwargs)
     53         return self.loss_with_mutable_state(
---> 54             rng_key, param_map, model, guide, *args, **kwargs
     55         )["loss"]

~/Anaconda/anaconda3/lib/python3.7/site-packages/numpyro/infer/ in loss_with_mutable_state(self, rng_key, param_map, model, guide, *args, **kwargs)
    149         if self.num_particles == 1:
--> 150             elbo, mutable_state = single_particle_elbo(rng_key)
    151             return {"loss": -elbo, "mutable_state": mutable_state}

~/Anaconda/anaconda3/lib/python3.7/site-packages/numpyro/infer/ in single_particle_elbo(rng_key)
    124             model_log_density, model_trace = log_density(
--> 125                 seeded_model, args, kwargs, params
    126             )

~/Anaconda/anaconda3/lib/python3.7/site-packages/numpyro/infer/ in log_density(model, model_args, model_kwargs, params)
     52     model = substitute(model, data=params)
---> 53     model_trace = trace(model).get_trace(*model_args, **model_kwargs)
     54     log_joint = jnp.zeros(())

~/Anaconda/anaconda3/lib/python3.7/site-packages/numpyro/ in get_trace(self, *args, **kwargs)
    164         """
--> 165         self(*args, **kwargs)
    166         return self.trace

~/Anaconda/anaconda3/lib/python3.7/site-packages/numpyro/ in __call__(self, *args, **kwargs)
     86         with self:
---> 87             return self.fn(*args, **kwargs)

~/Anaconda/anaconda3/lib/python3.7/site-packages/numpyro/ in __call__(self, *args, **kwargs)
     86         with self:
---> 87             return self.fn(*args, **kwargs)

~/Anaconda/anaconda3/lib/python3.7/site-packages/numpyro/ in __call__(self, *args, **kwargs)
     86         with self:
---> 87             return self.fn(*args, **kwargs)

~/Anaconda/anaconda3/lib/python3.7/site-packages/numpyro/ in __call__(self, *args, **kwargs)
     86         with self:
---> 87             return self.fn(*args, **kwargs)

/var/folders/58/rcr3xzjn6bscyr8llj4t161m0000gn/T/ipykernel_3097/ in model(data, prior_distribution, dir_mat, s, sample_vars, data_bl, aux_predictors_mat, intercept_aux, predictive)
     29                                                    intercept_aux, rtol=1e-6, atol=1e-5,
---> 30                                                    mxstep=1000))(data_bl_stocks, data_bl_constants) 
     31         y_pred = jnp.concatenate(z_array)

/var/folders/58/rcr3xzjn6bscyr8llj4t161m0000gn/T/ipykernel_3097/ in <lambda>(y0, constants)
     29                                                    intercept_aux, rtol=1e-6, atol=1e-5,
---> 30                                                    mxstep=1000))(data_bl_stocks, data_bl_constants) 
     31         y_pred = jnp.concatenate(z_array)

~/Anaconda/anaconda3/lib/python3.7/site-packages/jax/experimental/ in odeint(func, y0, t, rtol, atol, mxstep, *args)
    172   converted, consts = custom_derivatives.closure_convert(func, y0, t[0], *args)
--> 173   return _odeint_wrapper(converted, rtol, atol, mxstep, y0, t, *args, *consts)

~/Anaconda/anaconda3/lib/python3.7/site-packages/jax/experimental/ in _odeint_wrapper(func, rtol, atol, mxstep, y0, ts, *args)
    178   func = ravel_first_arg(func, unravel)
--> 179   out = _odeint(func, rtol, atol, mxstep, y0, ts, *args)
    180   return jax.vmap(unravel)(out)

~/Anaconda/anaconda3/lib/python3.7/site-packages/jax/experimental/ in _odeint_fwd(func, rtol, atol, mxstep, y0, ts, *args)
    217 def _odeint_fwd(func, rtol, atol, mxstep, y0, ts, *args):
--> 218   ys = _odeint(func, rtol, atol, mxstep, y0, ts, *args)
    219   return ys, (ys, ts, args)

FilteredStackTrace: TypeError: can't apply forward-mode autodiff (jvp) to a custom_vjp function.

Oh I see, the error message is really helpful. Thanks! It says that hessian (which combines forward-mode and reverse-mode differentiation: hessian(f)(x) = jacfwd(jacrev(f))(x)) does not work with the custom_vjp used in odeint. I guess we can use jacrev(jacrev(f))(x) there. We can add a configuration option to the Laplace approximation to control this behavior. Could you test it first? If it resolves the issue, could you make a feature request for this?
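A minimal sketch of the reverse-over-reverse Hessian on a toy custom_vjp function (`my_sin` and `loss` are hypothetical). jax.hessian goes through forward mode, which custom_vjp functions do not support; nesting jax.jacrev twice only uses reverse mode. Whether the same works for a loss built on odeint depends on the JAX version, so this only illustrates the idea.

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def my_sin(x):
    return jnp.sin(x)

def my_sin_fwd(x):
    return my_sin(x), x          # residual: the input itself

def my_sin_bwd(x, g):
    return (jnp.cos(x) * g,)     # d/dx sin(x) = cos(x)

my_sin.defvjp(my_sin_fwd, my_sin_bwd)

def loss(x):
    return jnp.sum(my_sin(x))

x = jnp.array([0.1, 0.2, 0.3])

# jax.hessian(loss) is jacfwd(jacrev(loss)); reverse-over-reverse avoids forward mode.
H = jax.jacrev(jax.jacrev(loss))(x)
```

The result should be the diagonal matrix diag(-sin(x)).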

Any other suggestions?

I couldn’t think of anything else off the top of my head. I guess I need some reproducible code for this one.

hi @fehiepsi, that sounds like a good solution!

The following works for me, while jacfwd() gives the mentioned error.

f = lambda y0 : odeint(dz_dt_loop, y0, ts, K, 
                       intercept_aux,  rtol=1e-6, 
                       atol=1e-5, mxstep=1000) 


Applying jacrev() twice here gives me the following error, but I’m guessing that’s unrelated, so I opened an issue on the NumPyro GitHub.

FilteredStackTrace: ValueError: vmap has mapped output but out_axes is None

Here’s the issue:

I hope this is what you meant!