Yes, it is better to use scan when you need gradients. It seems that some of your scan inputs/outputs don't have the expected dtype. If you can't share a minimal reproducible example, you can reach out to the JAX devs to interpret the error message. Have you tried creating a JAX function with just the scan, to isolate the issue?
Thanks a lot for the input! I've been working on the problem some more today, but I haven't solved the issue yet.
When I use jax.test_util's check_grads function, I get the following error, which doesn't refer back to my own code anywhere.
TypeError: primal and tangent arguments to jax.jvp do not match; dtypes must be equal, or in case of int/bool primal dtype the tangent dtype must be float0.Got primal dtype int32 and so expected tangent dtype [('float0', 'V')], but got tangent dtype int32 instead.
And when I use the code in a NumPyro model, either to sample from it or to find the MAP, I get the following error:
KeyError: dtype([('float0', 'V')])
Here's a minimal example:
```python
import jax
import jax.numpy as jnp
from jax import vmap
import numpy as np
from jax.experimental.ode import odeint

def dz_dt(z, t, A, constants, intercept_stocks,
          aux_predictors_mat, intercept_aux):
    variables, stored_vars = aux_scan(z, constants, aux_predictors_mat, intercept_aux)
    return jnp.matmul(A, variables) + intercept_stocks

def aux_scan(z, constants, aux_predictors_mat, intercept_aux):
    n = aux_predictors_mat.shape[0]
    xs = jnp.arange(0, n-1)  # Same error if I use numpy.arange()
    auxiliaries = jnp.array([0.0] * n)  # Same error if I use numpy.array()
    variables_init = jnp.append(jnp.append(auxiliaries, z), constants)

    def scan_body(variables, num_aux):  # carry, predictors_mat):
        # variables.at[i].set(v) is the current spelling of
        # jax.ops.index_update(variables, i, v), which newer JAX removed.
        variables = variables.at[num_aux].set(
            intercept_aux[num_aux] + jnp.matmul(variables, aux_predictors_mat[num_aux]))
        return variables, variables

    return jax.lax.scan(scan_body, init=variables_init, xs=xs)

ts = jnp.array(np.array([0.0, 3.0] + list(np.linspace(0, 60, int(60/6) + 1)[1:])))
z_init = jnp.array([[-1.0687675 , -0.39805314, -0.27391016, -0.93831366],
                    [-0.8567926 ,  0.344659  , -0.56589156, -0.0977652 ]])
constants = jnp.array([[ 1.2550328, -0.7836494],
                       [-0.7967919,  0.8055391]])
K = jnp.array([[ 1., 0., 0., -1.,  1.,  1.,  0.,  0.],
               [ 0., 0., 0.,  0., -1., -1., -1., -1.],
               [ 1., 1., 1.,  0.,  0.,  1.,  1.,  1.],
               [ 1., 1., 1.,  0.,  1.,  1.,  1.,  1.]])
params_aux_mat = jnp.array([[0., 0., 0.01, 0., 0., 0.01, 0.01, 0.01],
                            [0., 0., 0.01, 0., 0., 0.,   0.01, 0.01]])
intercept_aux = jnp.array([0, 0])
intercept_stocks = jnp.array([0] * 4)

ode_sol = vmap(lambda y0, constants: odeint(dz_dt, y0, ts, K,
                                            constants,
                                            intercept_stocks,
                                            params_aux_mat,
                                            intercept_aux))(z_init, constants)
```
Calling the function with vmap() just works, but check_grads gives the error mentioned above.
```python
from jax.test_util import check_grads
check_grads(dz_dt, (z_init, ts, K, constants, intercept_stocks, params_aux_mat, intercept_aux), 1, eps=1e-2)
```
I'm not sure how `check_grads` works, but it seems that `dz_dt` does not work on the batched arrays `z_init`, `constants`. Did you mean to use `vmap` there? I changed
```python
intercept_aux = jnp.array([0., 0.])
intercept_stocks = jnp.array([0.] * 4)
```
and it seems to work with
```python
target = lambda y0, constants: odeint(
    dz_dt, y0, ts, K, constants,
    intercept_stocks, params_aux_mat, intercept_aux).sum()
jax.grad(target)(z_init[0], constants[0]), jax.jacobian(jax.vmap(target))(z_init, constants)
```
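A minimal sketch of why the integer zeros mattered: `jnp.array([0, 0])` is an integer array, `check_grads` appears to build a random tangent with the same dtype for every argument, and `jax.jvp` requires integer primals to come with `float0` tangents, which is exactly the complaint in the error above. The function `shifted_sum` is hypothetical; only the dtype behavior is the point.

```python
import jax.numpy as jnp
from jax.test_util import check_grads

def shifted_sum(x, offset):
    # Hypothetical toy function, linear in both arguments.
    return jnp.sum(2.0 * x + offset)

x = jnp.array([1.0, 2.0])
offset_int = jnp.array([0, 0])        # integer literals -> integer array
offset_float = jnp.array([0.0, 0.0])  # float literals -> float array
print(offset_int.dtype, offset_float.dtype)

# Float offset: forward- and reverse-mode checks both run.
check_grads(shifted_sum, (x, offset_float), order=1, eps=1e-2)

# Integer offset: its tangent is not float0, so jax.jvp rejects it
# (a TypeError in the JAX versions discussed in this thread).
try:
    check_grads(shifted_sum, (x, offset_int), order=1, eps=1e-2)
    int_check_failed = False
except Exception as err:
    int_check_failed = True
    print(type(err).__name__, err)
```

Writing literals as `0.` instead of `0` (or passing `dtype=jnp.float32`) keeps every differentiated argument floating point.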
Wow, it seems like indeed turning the zeros into floats rather than integers helped the code run! I’m not getting the correct coefficients back yet but that’s likely a different error in my code somewhere. Thanks!
Is it a known problem that the MCMC(NUTS(…)) object sometimes takes forever after .run(), when calling, for instance, .get_samples() or .print_summary()? Any hints about what it means when this happens? Did my sampling fail?
I see this with the model above when I run NUTS, but not when I run SVI with Minimize. In the latter case I just get incorrect coefficients.
By the way, you asked whether I meant to use vmap(). I'm a bit confused now, as you recommended it to me earlier here:
Am I not using it correctly?
Never mind. In your reproducible code, you applied dz_dt to a batch of inputs. I think you need to vmap dz_dt there (in the check_grads line).
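A minimal sketch of that suggestion, assuming a per-example function that expects a single state vector: wrap it in `jax.vmap` (mapping over the batch axis of the state and broadcasting the shared matrix with `in_axes=(0, None)`) before handing it to `check_grads`. `rhs` and the toy matrix are hypothetical stand-ins for `dz_dt` and `K`.

```python
import jax
import jax.numpy as jnp
from jax.test_util import check_grads

def rhs(z, A):
    # Hypothetical per-example function: z is a single state vector, shape (4,).
    return A @ z

A = jnp.array([[1.0, 0.0, 0.0, -1.0],
               [0.0, 0.5, 0.0,  0.0],
               [1.0, 1.0, 1.0,  0.0],
               [0.0, 0.0, 0.2,  1.0]])
z_batch = jnp.array([[-1.07, -0.40, -0.27, -0.94],
                     [-0.86,  0.34, -0.57, -0.10]])

# Map over axis 0 of z_batch; A is shared by every example (in_axes=None).
batched_rhs = jax.vmap(rhs, in_axes=(0, None))
out = batched_rhs(z_batch, A)  # shape (2, 4)

# Check gradients of the batched function, not the per-example one.
check_grads(batched_rhs, (z_batch, A), order=1, eps=1e-2)
```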
Thanks! I’ll just leave it as is.
Just to not let the other question die out:
Also, although I do get parameter estimates from SVI with Minimize() (which are incorrect), when I then call `guide.get_posterior(params)` to check the uncertainty, I again get the error I mentioned earlier:
TypeError: can't apply forward-mode autodiff (jvp) to a custom_vjp function.
NUTS should not cause an infinite loop. I guess your odeint does not stop. You might want to set `mxstep`, like in the ODE example (I ran into an infinite loop in the past and submitted that flag to JAX; hopefully it also resolves your case).
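For reference, a minimal sketch of passing `mxstep` to `jax.experimental.ode.odeint`; per its docstring it caps the number of solver steps taken for each output time point (the default is unlimited). The linear right-hand side mirrors the toy model later in this thread, and the tolerance values are illustrative, not tuned.

```python
import jax.numpy as jnp
from jax.experimental.ode import odeint

def dy_dt(y, t, A):
    return A @ y

A = jnp.array([[0.0, 0.01],
               [0.01, 0.0]])
y0 = jnp.array([-0.06, -0.01])
ts = jnp.linspace(0.0, 240.0, 81)

# mxstep bounds the adaptive solver's inner loop, so a hard-to-integrate
# system stops after a fixed number of steps per time point instead of
# running indefinitely.
ys = odeint(dy_dt, y0, ts, A, rtol=1e-6, atol=1e-5, mxstep=1000)
print(ys.shape)
```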
can’t apply forward-mode autodiff (jvp) to a custom_vjp function
I think Minimize uses forward mode, but odeint has a custom VJP, so this use case is not supported? It would help to point out where the error happens, or at least share some tracebacks, so we can help better. (It is always tricky to identify the error from the error message alone, especially when the message does not come from our framework.) For example, knowing which lines in `get_posterior` cause the issue might give us more concrete ideas, but it is best to have some reproducible code (just remove all unnecessary stuff; for example, you can make `dz_dt` behave like an identity function or remove most of the latent samples…).
Hi fehiepsi! Thanks a lot for your input again. Unfortunately, the issue with NUTS wasn't solved by adding the `mxstep` option. I did figure out that the problem is not related to scan(), as the infinite load persists when I reintroduce the for-loop. Any other suggestions?
Regarding the error for Minimize(), I'll share the traceback below. If that isn't enough, I'll try another minimal example; I'll just have to add a model function to my previous example in this topic.
Here's the "short" traceback, which excludes JAX's internal frames (the full version was too long):
```
FilteredStackTrace Traceback (most recent call last)
/var/folders/58/rcr3xzjn6bscyr8llj4t161m0000gn/T/ipykernel_3097/1815523977.py in <module>
      1 #Predictive(model, guide, num_samples=2000)
----> 2 guide.quantiles(params, 0.5)

~/Anaconda/anaconda3/lib/python3.7/site-packages/numpyro/infer/autoguide.py in quantiles(self, params, quantiles)
   1110     def quantiles(self, params, quantiles):
-> 1111         transform = self.get_transform(params)
   1112         quantiles = jnp.array(quantiles)[..., None]

~/Anaconda/anaconda3/lib/python3.7/site-packages/numpyro/infer/autoguide.py in get_transform(self, params)
   1082         loc = params["{}_loc".format(self.prefix)]
-> 1083         precision = hessian(loss_fn)(loc)
   1084         scale_tril = cholesky_of_inverse(precision)

~/Anaconda/anaconda3/lib/python3.7/site-packages/numpyro/infer/autoguide.py in loss_fn(z)
   1079             params1["{}_loc".format(self.prefix)] = z
-> 1080             return self._loss_fn(params1)
   1081

~/Anaconda/anaconda3/lib/python3.7/site-packages/numpyro/infer/autoguide.py in loss_fn(params)
   1062             return Trace_ELBO().loss(
-> 1063                 random.PRNGKey(0), params, self.model, self, *args, **kwargs
   1064             )

~/Anaconda/anaconda3/lib/python3.7/site-packages/numpyro/infer/elbo.py in loss(self, rng_key, param_map, model, guide, *args, **kwargs)
     53         return self.loss_with_mutable_state(
---> 54             rng_key, param_map, model, guide, *args, **kwargs
     55         )["loss"]

~/Anaconda/anaconda3/lib/python3.7/site-packages/numpyro/infer/elbo.py in loss_with_mutable_state(self, rng_key, param_map, model, guide, *args, **kwargs)
    149         if self.num_particles == 1:
--> 150             elbo, mutable_state = single_particle_elbo(rng_key)
    151             return {"loss": -elbo, "mutable_state": mutable_state}

~/Anaconda/anaconda3/lib/python3.7/site-packages/numpyro/infer/elbo.py in single_particle_elbo(rng_key)
    124             model_log_density, model_trace = log_density(
--> 125                 seeded_model, args, kwargs, params
    126             )

~/Anaconda/anaconda3/lib/python3.7/site-packages/numpyro/infer/util.py in log_density(model, model_args, model_kwargs, params)
     52     model = substitute(model, data=params)
---> 53     model_trace = trace(model).get_trace(*model_args, **model_kwargs)
     54     log_joint = jnp.zeros(())

~/Anaconda/anaconda3/lib/python3.7/site-packages/numpyro/handlers.py in get_trace(self, *args, **kwargs)
    164         """
--> 165         self(*args, **kwargs)
    166         return self.trace

~/Anaconda/anaconda3/lib/python3.7/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
     86         with self:
---> 87             return self.fn(*args, **kwargs)
     88

~/Anaconda/anaconda3/lib/python3.7/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
     86         with self:
---> 87             return self.fn(*args, **kwargs)
     88

~/Anaconda/anaconda3/lib/python3.7/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
     86         with self:
---> 87             return self.fn(*args, **kwargs)
     88

~/Anaconda/anaconda3/lib/python3.7/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
     86         with self:
---> 87             return self.fn(*args, **kwargs)
     88

/var/folders/58/rcr3xzjn6bscyr8llj4t161m0000gn/T/ipykernel_3097/3862560390.py in model(data, prior_distribution, dir_mat, s, sample_vars, data_bl, aux_predictors_mat, intercept_aux, predictive)
     29                                 intercept_aux, rtol=1e-6, atol=1e-5,
---> 30                                 mxstep=1000))(data_bl_stocks, data_bl_constants)
     31     y_pred = jnp.concatenate(z_array)

/var/folders/58/rcr3xzjn6bscyr8llj4t161m0000gn/T/ipykernel_3097/3862560390.py in <lambda>(y0, constants)
     29                                 intercept_aux, rtol=1e-6, atol=1e-5,
---> 30                                 mxstep=1000))(data_bl_stocks, data_bl_constants)
     31     y_pred = jnp.concatenate(z_array)

~/Anaconda/anaconda3/lib/python3.7/site-packages/jax/experimental/ode.py in odeint(func, y0, t, rtol, atol, mxstep, *args)
    172     converted, consts = custom_derivatives.closure_convert(func, y0, t[0], *args)
--> 173     return _odeint_wrapper(converted, rtol, atol, mxstep, y0, t, *args, *consts)
    174

~/Anaconda/anaconda3/lib/python3.7/site-packages/jax/experimental/ode.py in _odeint_wrapper(func, rtol, atol, mxstep, y0, ts, *args)
    178     func = ravel_first_arg(func, unravel)
--> 179     out = _odeint(func, rtol, atol, mxstep, y0, ts, *args)
    180     return jax.vmap(unravel)(out)

~/Anaconda/anaconda3/lib/python3.7/site-packages/jax/experimental/ode.py in _odeint_fwd(func, rtol, atol, mxstep, y0, ts, *args)
    217 def _odeint_fwd(func, rtol, atol, mxstep, y0, ts, *args):
--> 218     ys = _odeint(func, rtol, atol, mxstep, y0, ts, *args)
    219     return ys, (ys, ts, args)

FilteredStackTrace: TypeError: can't apply forward-mode autodiff (jvp) to a custom_vjp function.
```
Oh I see, the traceback is really helpful. Thanks! It says that `hessian` (which combines forward-mode and reverse-mode differentiation: `hessian(f)(x) = jacfwd(jacrev(f))(x)`) does not work with `custom_vjp` (used in `odeint`). I guess we can use `jacrev(jacrev(f))(x)` there. We can add a configuration option to the Laplace approximation to control this behavior. Could you test it first? If it resolves the issue, could you then make a feature request for this?
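A small self-contained sketch of the forward/reverse distinction, using a hypothetical `custom_vjp` function `cube` whose forward rule calls itself, the way the traceback above shows `_odeint_fwd` calling `_odeint`. Reverse-over-reverse (`jacrev(jacrev(f))`) only ever runs the user-supplied backward rule, while `jax.hessian` (forward-over-reverse) makes the forward-mode pass reach the `custom_vjp` function and fail.

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def cube(x):
    return x ** 3

def cube_fwd(x):
    # Like odeint's forward rule, call the custom_vjp function itself.
    return cube(x), x

def cube_bwd(x, g):
    return (3.0 * x ** 2 * g,)

cube.defvjp(cube_fwd, cube_bwd)

def loss(x):
    return jnp.sum(cube(x))

x = jnp.array([1.0, 2.0])

# Reverse-over-reverse: differentiates through cube_bwd, which is ordinary code.
hess_rev = jax.jacrev(jax.jacrev(loss))(x)
print(hess_rev)  # diag(6 * x)

# Forward-over-reverse: the outer forward-mode pass meets cube's custom_vjp.
try:
    jax.hessian(loss)(x)
    hessian_failed = False
except Exception as err:  # a TypeError in the JAX versions in this thread
    hessian_failed = True
    print(type(err).__name__, err)
```

Reverse-over-reverse is usually somewhat more expensive than forward-over-reverse, but it is the combination that a `custom_vjp`-only function supports.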
Any other suggestions?
I can't think of anything else off the top of my head. I guess I need some reproducible code for this one.
hi @fehiepsi, that sounds like a good solution!
The following works for me, while jacfwd() gives the mentioned error.
```python
f = lambda y0: odeint(dz_dt_loop, y0, ts, K,
                      constants[0],
                      intercept_stocks,
                      params_aux_mat,
                      intercept_aux, rtol=1e-6,
                      atol=1e-5, mxstep=1000)
jax.jacrev(f)(z_init[0])
```
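The same contrast on a small linear ODE, as a sketch: reverse mode through `odeint` works, forward mode does not. For `dy/dt = A y` the Jacobian of the final state with respect to `y0` is the matrix exponential `expm(A t)`, which gives an independent check. The matrix `A`, the time grid, and the tolerances are assumptions for illustration.

```python
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint
from jax.scipy.linalg import expm

def dy_dt(y, t, A):
    return A @ y

A = jnp.array([[0.0, 0.1],
               [0.1, 0.0]])
ts = jnp.linspace(0.0, 5.0, 6)
y0 = jnp.array([1.0, -0.5])

def solve(y0):
    return odeint(dy_dt, y0, ts, A, rtol=1e-6, atol=1e-6, mxstep=1000)

# Reverse mode uses odeint's custom VJP (an adjoint solve): works.
jac = jax.jacrev(solve)(y0)  # shape (len(ts), 2, 2)
print(jnp.allclose(jac[-1], expm(A * ts[-1]), atol=1e-3))

# Forward mode would need a JVP rule, which a custom_vjp function lacks.
try:
    jax.jacfwd(solve)(y0)
    fwd_failed = False
except Exception as err:
    fwd_failed = True
    print(type(err).__name__, err)
```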
A double application of jacrev() here gives me the following error, but I'm guessing that's unrelated, so I opened an issue on the NumPyro GitHub.
FilteredStackTrace: ValueError: vmap has mapped output but out_axes is None
Here’s the issue: [Feature request] hessian(f)(x) = jacrev(jacrev(f))(x)) · Issue #1195 · pyro-ppl/numpyro · GitHub.
I hope this is what you meant!
So even with a double jacrev, there is still an issue. Could you make reproducible code so we can ask the JAX devs (it is hard for them to help without it)? I think that would be the easiest path, because this debugging thread has been quite long. To make reproducible code, you can:
- remove unnecessary operators
- remove unnecessary latent variables
- create fake and small data
- remove unnecessary arguments for intermediate functions
- remove unnecessary intermediate functions
- keep the reproducible code short, under 20 lines
Remove things one at a time, as long as the error still happens.
OK thanks @fehiepsi, I've made some progress again. First, it seems that the infinite loop issue with NUTS was resolved by setting progress_bar=True. Strangely…
Next, here is the minimal example reproducing the forward-mode differentiation error. The last line should cause the error.
```python
import jax
import jax.numpy as jnp
from jax import random
from jax.experimental.ode import odeint
import numpyro
import numpyro.distributions as dist
from numpyro.infer.autoguide import AutoLaplaceApproximation
from numpyro.infer import SVI, Trace_ELBO

def dy_dt(z, t, A):
    return jnp.matmul(A, z)

def model(data, ts, ind_mat):
    theta = numpyro.sample("theta", dist.Normal(loc=jnp.array([0.] * 2),
                                                scale=jnp.array([1.0] * 2)))
    # Put the sampled coefficients at the nonzero positions of ind_mat
    # (.at[...].set replaces the removed jax.ops.index_update).
    A = jnp.zeros(ind_mat.shape).at[ind_mat].set(theta)
    # Use the `data` argument rather than the global fake_data.
    y_pred = odeint(dy_dt, data[0], ts, A, rtol=1e-6, atol=1e-5, mxstep=1000)
    sigma = numpyro.sample("sigma", dist.LogNormal(-1, 1).expand([2]))
    numpyro.sample("y_pred", dist.Normal(y_pred[1:], sigma), obs=data[1:])

def get_MAP(data, ts, ind_mat):
    optimizer = numpyro.optim.Minimize(method="BFGS")
    guide = AutoLaplaceApproximation(model)
    svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
    init_state = svi.init(random.PRNGKey(0), data, ts, ind_mat)
    optimal_state, loss = svi.update(init_state, data, ts, ind_mat)
    params = svi.get_params(optimal_state)
    return params, guide, svi, optimal_state, loss

ts = jnp.linspace(0.0, 240.0, 81)
y_init = jnp.array([-0.06, -0.01])
K = jnp.array([[ 0.,   0.01],
               [ 0.01, 0.  ]])
ind_mat = (K != 0)
fake_data = odeint(dy_dt, y_init, ts, K)

params, guide, svi, optimal_state, loss = get_MAP(fake_data, ts, ind_mat)
print(params)
print(guide.quantiles(params, 0.91))  # raises the forward-mode (jvp) error
```
infinite loop issue with NUTS was resolved by turning progress_bar=True
This is likely an XLA issue; XLA compilation is sometimes very slow on some systems.
The last line should cause the error.
This seems to be fixed in this PR. You can test it with
```
!pip install -q git+https://github.com/pyro-ppl/numpyro.git@refs/pull/1196/head
```
```python
guide = AutoLaplaceApproximation(model,
                                 get_precision=lambda f: jax.jacobian(jax.jacobian(f)))
```
Thanks! Nice that it was already fixed.
Sorry for this extremely noob question but: the first line is meant to pull the fixed numpyro version, right? So I run that in terminal, and then in my model when I assign the guide I add the get_precision setting?
```
!pip install -q git+https://github.com/pyro-ppl/numpyro.git@refs/pull/1196/head
```
I'm asking because when I run this line in the terminal, it installs a few requirements but still doesn't let me use "get_precision".
It tells me:
TypeError: __init__() got an unexpected keyword argument 'get_precision'
That's right. I tested it on Colab, but I'm not sure whether the installation command works in a terminal. You might want to uninstall numpyro first and remove the -q option to get more informative messages.
Thanks @fehiepsi, I also got it to work on Colab. Excellent!
I guess even this long thread can come to an end. By the way, did you notice anything odd about my code? I haven't been able to recover the parameters I created the toy data with (also when I do repeated updates, or use svi.run() instead).
I think you can use `svi.run`. Your code looks great to me. FYI, when the PR (which is under review) is merged, the `get_precision` name/pattern may change, so please follow the description in the docs.