Making a for-loop more efficient

Yes, it is better to use scan when you need gradients. It seems that some of your scan inputs/outputs don't have the expected dtype. If you can't share minimal reproducible code, you can reach out to the JAX devs to interpret the error message. Have you tried creating a JAX function with just scan to isolate the issue?
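For example (just a sketch, not your model): a scan whose carry and inputs are all floats should pass check_grads on its own, which helps rule scan itself out.

import jax
import jax.numpy as jnp
from jax.test_util import check_grads

def cumsum_scan(x):
    # float carry and float xs, so every scan input/output is differentiable
    def body(carry, xi):
        carry = carry + xi
        return carry, carry
    _, ys = jax.lax.scan(body, 0.0, x)
    return ys

check_grads(cumsum_scan, (jnp.arange(4.0),), order=1)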

Thanks a lot for the input! I've been working on the problem some more today, but I haven't solved the issue yet.

When I use jax.test_util's check_grads function I get the following error, which does not refer back to my own code anywhere.
TypeError: primal and tangent arguments to jax.jvp do not match; dtypes must be equal, or in case of int/bool primal dtype the tangent dtype must be float0. Got primal dtype int32 and so expected tangent dtype [('float0', 'V')], but got tangent dtype int32 instead.

And when I try to use the code in a NumPyro model, either to sample from or to find the MAP of, I get the following error:
KeyError: dtype([('float0', 'V')])

Here's a minimal example:

import jax
import jax.numpy as jnp
from jax import vmap, ops
import numpy as np
from jax.experimental.ode import odeint

def dz_dt(z, t, A, constants, intercept_stocks, 
    aux_predictors_mat, intercept_aux): 
    variables, stored_vars = aux_scan(z, constants, aux_predictors_mat, intercept_aux)
    return jnp.matmul(A, variables) + intercept_stocks

def aux_scan(z, constants, aux_predictors_mat, intercept_aux):
    n = aux_predictors_mat.shape[0]
    xs = jnp.arange(0, n-1)    # Same error if I use numpy.arange()
    auxiliaries = jnp.array([0.0] * n)   # Same error if I use numpy.array()
    variables_init = jnp.append(jnp.append(auxiliaries, z), constants) 
    
    def scan_body(variables, num_aux):  # carry = variables, scanned value = aux index
        variables = ops.index_update(variables, num_aux, intercept_aux[num_aux] +
                                         jnp.matmul(variables, aux_predictors_mat[num_aux]))
        return variables, variables 

    return jax.lax.scan(scan_body, init=variables_init, xs=xs)

ts = jnp.array(np.array([0.0, 3.0] + list(np.linspace(0, 60, int(60/6) + 1)[1:])))

z_init = jnp.array([[-1.0687675 , -0.39805314, -0.27391016, -0.93831366],
                       [-0.8567926 ,  0.344659  , -0.56589156, -0.0977652 ]])
constants = jnp.array([[ 1.2550328, -0.7836494],
                      [-0.7967919,  0.8055391]])

K = jnp.array([[ 1.,  0.,  0., -1.,  1.,  1.,  0.,  0.],
             [ 0.,  0.,  0.,  0., -1., -1., -1., -1.],
             [ 1.,  1.,  1.,  0.,  0.,  1.,  1.,  1.],
             [ 1.,  1.,  1.,  0.,  1.,  1.,  1.,  1.]])

params_aux_mat = jnp.array([[0.  , 0.  , 0.01, 0.  , 0.  , 0.01, 0.01, 0.01],
                             [0.  , 0.  , 0.01, 0.  , 0.  , 0.  , 0.01, 0.01]])

intercept_aux = jnp.array([0, 0])
intercept_stocks = jnp.array([0] * 4)

ode_sol = vmap(lambda y0, constants: odeint(dz_dt, y0, ts, K, 
                                  constants, 
                                  intercept_stocks, 
                                  params_aux_mat,   
                                  intercept_aux))(z_init, constants)

Calling the function with vmap() just works, but check_grads gives the above-mentioned error.

from jax.test_util import check_grads

check_grads(dz_dt, (z_init, ts, K, constants, intercept_stocks, params_aux_mat, intercept_aux), 1, eps=1e-2)

I'm not sure how check_grads works, but it seems that dz_dt does not work on the batch arrays z_init, constants. Do you mean to use vmap there? I changed

intercept_aux = jnp.array([0., 0.])
intercept_stocks = jnp.array([0.] * 4)

and it seems to work with

target = lambda y0, constants: odeint(
    dz_dt, y0, ts, K, constants, 
    intercept_stocks, params_aux_mat, intercept_aux).sum()
jax.grad(target)(z_init[0], constants[0]), jax.jacobian(jax.vmap(target))(z_init, constants)
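(The integer arrays were the problem because of the rule in your check_grads error: an integer primal can only take a float0 tangent, so ordinary tangents cannot be pushed through it. An isolated illustration, not taken from your model:)

import jax
import jax.numpy as jnp

f = lambda x: x * 2.0
# jax.jvp(f, (jnp.array(3, dtype=jnp.int32),), (jnp.array(1, dtype=jnp.int32),))
# ^ raises the same TypeError: an int32 primal expects a float0 tangent
jax.jvp(f, (jnp.array(3.0),), (jnp.array(1.0),))   # float primal + float tangent works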

Wow, it seems that turning the zeros into floats rather than integers indeed got the code to run! I'm not getting the correct coefficients back yet, but that's likely a different error somewhere in my code. Thanks!

Is it by any chance a familiar problem that the object returned by MCMC(NUTS(…)).run() sometimes takes forever to load when calling, for instance, .get_samples() or print_summary()? Any hints as to what it means when this happens? Did my sampling fail?

I see this with the above model when I run NUTS, but not when I run SVI with Minimize. In the latter case I just get incorrect coefficients.

By the way, you asked whether I meant to use vmap(). I'm a bit confused now, as you recommended it to me earlier here:

Am I not using it correctly?

Never mind. In your reproducible code, you applied dz_dt to a batch of inputs. I think you need to vmap dz_dt there (in the check_grads line).
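For example, something like this (only a sketch; it assumes the float versions of intercept_stocks / intercept_aux above and fixes t and the remaining arguments):

from jax import vmap
from jax.test_util import check_grads

# map dz_dt over the leading batch axis of z_init and constants
batched_dz_dt = vmap(lambda z, c: dz_dt(z, ts[0], K, c, intercept_stocks,
                                        params_aux_mat, intercept_aux))
check_grads(batched_dz_dt, (z_init, constants), order=1, eps=1e-2)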

Thanks! I’ll just leave it as is.
Just so the other question doesn't die out:

Also, although I do get parameter estimates from SVI / Minimize() (which are incorrect), when I then call guide.get_posterior(params) to check the uncertainty, I again get the error I mentioned earlier, i.e.:
TypeError: can't apply forward-mode autodiff (jvp) to a custom_vjp function.

NUTS should not cause an infinite loop. I guess your odeint does not stop. You might want to set mxstep like in the ode example (I faced an infinite loop in the past and submitted that flag to JAX; hopefully it also resolves your case).
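For example, sketching with the arguments from your earlier example:

ys = odeint(dz_dt, z_init[0], ts, K, constants[0],
            intercept_stocks, params_aux_mat, intercept_aux,
            rtol=1e-6, atol=1e-5, mxstep=1000)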

can’t apply forward-mode autodiff (jvp) to a custom_vjp function

I think Minimize uses forward-mode, but odeint has a custom vjp, so this usage is not supported? It would be nice to point out where the error happens, at least some traceback, so we can help better. (It is always tricky to identify the error from the error message alone, especially when the message does not come from our framework.) For example, knowing which lines in get_posterior cause the issue might give us more concrete ideas, but it is best to have some reproducible code (by removing all unnecessary stuff; for example, you can just make dz_dt behave like an identity function or remove most of the latent samples…).

Hi fehiepsi! Thanks a lot for your input again. Unfortunately the issue with NUTS wasn't solved by adding the mxstep option. I did figure out that the problem is not related to scan(), as the infinite load persists when I reintroduce the for-loop. Any other suggestions?

Regarding the error for minimize(), I'll share the traceback below. If that won't do, I'll try another minimal example; I'll just have to add a model function to my previous example in this topic.

Here's the "short" error that excludes JAX's internal frames (the longer version was too long):

FilteredStackTrace                        Traceback (most recent call last)
/var/folders/58/rcr3xzjn6bscyr8llj4t161m0000gn/T/ipykernel_3097/1815523977.py in <module>
      1 #Predictive(model, guide, num_samples=2000)
----> 2 guide.quantiles(params, 0.5)

~/Anaconda/anaconda3/lib/python3.7/site-packages/numpyro/infer/autoguide.py in quantiles(self, params, quantiles)
   1110     def quantiles(self, params, quantiles):
-> 1111         transform = self.get_transform(params)
   1112         quantiles = jnp.array(quantiles)[..., None]

~/Anaconda/anaconda3/lib/python3.7/site-packages/numpyro/infer/autoguide.py in get_transform(self, params)
   1082         loc = params["{}_loc".format(self.prefix)]
-> 1083         precision = hessian(loss_fn)(loc)
   1084         scale_tril = cholesky_of_inverse(precision)

~/Anaconda/anaconda3/lib/python3.7/site-packages/numpyro/infer/autoguide.py in loss_fn(z)
   1079             params1["{}_loc".format(self.prefix)] = z
-> 1080             return self._loss_fn(params1)
   1081 

~/Anaconda/anaconda3/lib/python3.7/site-packages/numpyro/infer/autoguide.py in loss_fn(params)
   1062             return Trace_ELBO().loss(
-> 1063                 random.PRNGKey(0), params, self.model, self, *args, **kwargs
   1064             )

~/Anaconda/anaconda3/lib/python3.7/site-packages/numpyro/infer/elbo.py in loss(self, rng_key, param_map, model, guide, *args, **kwargs)
     53         return self.loss_with_mutable_state(
---> 54             rng_key, param_map, model, guide, *args, **kwargs
     55         )["loss"]

~/Anaconda/anaconda3/lib/python3.7/site-packages/numpyro/infer/elbo.py in loss_with_mutable_state(self, rng_key, param_map, model, guide, *args, **kwargs)
    149         if self.num_particles == 1:
--> 150             elbo, mutable_state = single_particle_elbo(rng_key)
    151             return {"loss": -elbo, "mutable_state": mutable_state}

~/Anaconda/anaconda3/lib/python3.7/site-packages/numpyro/infer/elbo.py in single_particle_elbo(rng_key)
    124             model_log_density, model_trace = log_density(
--> 125                 seeded_model, args, kwargs, params
    126             )

~/Anaconda/anaconda3/lib/python3.7/site-packages/numpyro/infer/util.py in log_density(model, model_args, model_kwargs, params)
     52     model = substitute(model, data=params)
---> 53     model_trace = trace(model).get_trace(*model_args, **model_kwargs)
     54     log_joint = jnp.zeros(())

~/Anaconda/anaconda3/lib/python3.7/site-packages/numpyro/handlers.py in get_trace(self, *args, **kwargs)
    164         """
--> 165         self(*args, **kwargs)
    166         return self.trace

~/Anaconda/anaconda3/lib/python3.7/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
     86         with self:
---> 87             return self.fn(*args, **kwargs)
     88 

~/Anaconda/anaconda3/lib/python3.7/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
     86         with self:
---> 87             return self.fn(*args, **kwargs)
     88 

~/Anaconda/anaconda3/lib/python3.7/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
     86         with self:
---> 87             return self.fn(*args, **kwargs)
     88 

~/Anaconda/anaconda3/lib/python3.7/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
     86         with self:
---> 87             return self.fn(*args, **kwargs)
     88 

/var/folders/58/rcr3xzjn6bscyr8llj4t161m0000gn/T/ipykernel_3097/3862560390.py in model(data, prior_distribution, dir_mat, s, sample_vars, data_bl, aux_predictors_mat, intercept_aux, predictive)
     29                                                    intercept_aux, rtol=1e-6, atol=1e-5,
---> 30                                                    mxstep=1000))(data_bl_stocks, data_bl_constants) 
     31         y_pred = jnp.concatenate(z_array)

/var/folders/58/rcr3xzjn6bscyr8llj4t161m0000gn/T/ipykernel_3097/3862560390.py in <lambda>(y0, constants)
     29                                                    intercept_aux, rtol=1e-6, atol=1e-5,
---> 30                                                    mxstep=1000))(data_bl_stocks, data_bl_constants) 
     31         y_pred = jnp.concatenate(z_array)

~/Anaconda/anaconda3/lib/python3.7/site-packages/jax/experimental/ode.py in odeint(func, y0, t, rtol, atol, mxstep, *args)
    172   converted, consts = custom_derivatives.closure_convert(func, y0, t[0], *args)
--> 173   return _odeint_wrapper(converted, rtol, atol, mxstep, y0, t, *args, *consts)
    174 

~/Anaconda/anaconda3/lib/python3.7/site-packages/jax/experimental/ode.py in _odeint_wrapper(func, rtol, atol, mxstep, y0, ts, *args)
    178   func = ravel_first_arg(func, unravel)
--> 179   out = _odeint(func, rtol, atol, mxstep, y0, ts, *args)
    180   return jax.vmap(unravel)(out)

~/Anaconda/anaconda3/lib/python3.7/site-packages/jax/experimental/ode.py in _odeint_fwd(func, rtol, atol, mxstep, y0, ts, *args)
    217 def _odeint_fwd(func, rtol, atol, mxstep, y0, ts, *args):
--> 218   ys = _odeint(func, rtol, atol, mxstep, y0, ts, *args)
    219   return ys, (ys, ts, args)

FilteredStackTrace: TypeError: can't apply forward-mode autodiff (jvp) to a custom_vjp function.

Oh I see, the error message is really helpful. Thanks! It says that hessian (which combines forward- and reverse-mode differentiation: hessian(f)(x) = jacfwd(jacrev(f))(x)) does not work with custom_vjp (in odeint). I guess we can use jacrev(jacrev(f))(x) there. We can add a configuration to the Laplace approximation to control this behavior. Could you test it first? If it resolves the issue, could you make a feature request for it?
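On a plain function (just to illustrate the two compositions, not your odeint case):

import jax
import jax.numpy as jnp

f = lambda x: jnp.sum(jnp.sin(x) ** 2)
x = jnp.arange(3.0)

h_fwd_rev = jax.hessian(f)(x)             # jacfwd(jacrev(f)): needs forward mode (jvp)
h_rev_rev = jax.jacrev(jax.jacrev(f))(x)  # reverse-over-reverse: no jvp involved
assert jnp.allclose(h_fwd_rev, h_rev_rev)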

Any other suggestions?

I can't think of anything else off the top of my head. I guess I need some reproducible code for this one.


Hi @fehiepsi, that sounds like a good solution!

The following works for me, while jacfwd() gives the mentioned error.

f = lambda y0 : odeint(dz_dt_loop, y0, ts, K, 
                       constants[0], 
                       intercept_stocks, 
                       params_aux_mat,   
                       intercept_aux,  rtol=1e-6, 
                       atol=1e-5, mxstep=1000) 

jax.jacrev(f)(z_init[0])

A double application of jacrev() here gives me the following error, but I'm guessing that's unrelated, so I opened an issue on the NumPyro GitHub.

FilteredStackTrace: ValueError: vmap has mapped output but out_axes is None

Here’s the issue: [Feature request] hessian(f)(x) = jacrev(jacrev(f))(x)) · Issue #1195 · pyro-ppl/numpyro · GitHub.

I hope this is what you meant!

So even with doubling jacrev, there is still an issue. I wonder if you can make reproducible code so we can ask the JAX devs (it is hard for them to help without reproducible code)? I think that would be the easiest solution, because this debugging thread has become quite long. For reproducible code, you can:

  • remove unnecessary operators
  • remove unnecessary latent variables
  • create fake and small data
  • remove unnecessary arguments for intermediate functions
  • remove unnecessary intermediate functions
  • make the reproducible example short, under 20 lines of code
You can remove things one by one as long as the error still happens.

OK thanks @fehiepsi, I've made some progress again. Firstly, it seems that the infinite loop issue with NUTS was resolved by turning on progress_bar=True. Strangely…

Then, regarding the minimal example to reproduce the forward-mode differentiation error: the last line should cause the error.

import jax
import jax.numpy as jnp
from jax import ops, random
import numpy as np
from jax.experimental.ode import odeint
import numpyro
import numpyro.distributions as dist
from numpyro.infer.autoguide import AutoLaplaceApproximation   
from numpyro.infer import SVI, Trace_ELBO

def dy_dt(z, t, A):
    return jnp.matmul(A, z) 

def model(data, ts, ind_mat):
    theta = numpyro.sample("theta", dist.Normal(loc=jnp.array([0.] * 2),
                                                scale=jnp.array([1.0]) * 2),)
    A = jnp.array(np.zeros([ind_mat.shape[0], ind_mat.shape[1]]))
    A = ops.index_update(A, ind_mat, theta)
    y_pred = odeint(dy_dt, fake_data[0], ts, A, rtol=1e-6, atol=1e-5, mxstep=1000) 
    sigma = numpyro.sample("sigma", dist.LogNormal(-1, 1).expand([2]))
    numpyro.sample("y_pred", dist.Normal(y_pred[1:], sigma), obs=fake_data[1:]) 

def get_MAP(data, ts, ind_mat):
    optimizer = numpyro.optim.Minimize(method="BFGS") 
    guide = AutoLaplaceApproximation(model)
    svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
    init_state = svi.init(random.PRNGKey(0), data, ts, ind_mat)
    optimal_state, loss = svi.update(init_state, data, ts, ind_mat)
    params = svi.get_params(optimal_state) 
    return params, guide, svi, optimal_state, loss 

ts = jnp.array(list(np.linspace(0, 240, 81)))
y_init = jnp.array([-0.06, -0.01])
K = jnp.array([[ 0., 0.01],
               [ 0.01,  0.]]) 
ind_mat = (K != 0)
fake_data = odeint(dy_dt, y_init, ts, K) 

params, guide, svi, optimal_state, loss = get_MAP(fake_data, ts, ind_mat)
print(params)
print(guide.quantiles(params, 0.91))

infinite loop issue with NUTS was resolved by turning progress_bar=True

This is likely an XLA issue; XLA is sometimes very slow on some systems.

The last line should cause the error.

This seems to be fixed in this PR. You can test it with

!pip install -q git+https://github.com/pyro-ppl/numpyro.git@refs/pull/1196/head
guide = AutoLaplaceApproximation(model,
    get_precision=lambda f: jax.jacobian(jax.jacobian(f)))

Thanks! Nice that it was already fixed.

Sorry for this extremely noob question, but: the first line is meant to pull the fixed NumPyro version, right? So I run that in the terminal, and then in my model, when I assign the guide, I add the get_precision setting?

!pip install -q git+https://github.com/pyro-ppl/numpyro.git@refs/pull/1196/head

I'm asking because when I do run this line in the terminal it installs a few requirements but does not let me use "get_precision".

It tells me:
TypeError: __init__() got an unexpected keyword argument 'get_precision'

That's right. I tested it with Colab, but I'm not sure whether the installation command works in a terminal. You might want to uninstall numpyro first and remove the -q option to get more informative messages.
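For example, in a notebook cell (drop the leading ! in a plain terminal):

!pip uninstall -y numpyro
!pip install git+https://github.com/pyro-ppl/numpyro.git@refs/pull/1196/head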

Thanks @fehiepsi, I also got it to work in Colab! :slight_smile: Excellent.

I guess even this long thread can come to an end :stuck_out_tongue: . By the way, did you notice anything odd about my code? I haven't been able to recover the parameters I created the toy data with (also not with repeated updates, or with svi.run() instead).

I think you can use svi.run. Your code looks great to me. FYI, when the PR (which is under review) is merged, the get_precision name/pattern may change, so please follow the description in the docs.
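For example, just a sketch (the number of steps here is an arbitrary choice):

svi_result = svi.run(random.PRNGKey(0), 10, fake_data, ts, ind_mat)
params = svi_result.params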
