Making a for-loop more efficient

So even with doubling jacrev, there is still an issue. I wonder if you can make reproducible code so we can ask the JAX devs (it is hard for them to help without reproducible code)? I think that would be the easiest solution, because this debugging thread has become quite long. To make the code reproducible, you can:

  • remove unnecessary operators
  • remove unnecessary latent variables
  • create fake and small data
  • remove unnecessary arguments for intermediate functions
  • remove unnecessary intermediate functions
  • keep the reproducible example short, under 20 lines of code

You can remove things one at a time, as long as the error still happens.

OK thanks @fehiepsi, I’ve made some progress again. Firstly, it seems that the infinite loop issue with NUTS was resolved by turning progress_bar=True on. Strangely…

Then, regarding the minimal example to reproduce the forward-mode differentiation error: here it is. The last line should cause the error.

import jax
import jax.numpy as jnp
from jax import ops, random
import numpy as np
from jax.experimental.ode import odeint
import numpyro
import numpyro.distributions as dist
from numpyro.infer.autoguide import AutoLaplaceApproximation   
from numpyro.infer import SVI, Trace_ELBO

def dy_dt(z, t, A):
    return jnp.matmul(A, z) 

def model(data, ts, ind_mat):
    theta = numpyro.sample("theta", dist.Normal(loc=jnp.array([0.] * 2),
                                                scale=jnp.array([1.0]) * 2),)
    A = jnp.array(np.zeros([ind_mat.shape[0], ind_mat.shape[1]]))
    A = ops.index_update(A, ind_mat, theta)
    # Use the `data` argument instead of reaching for the global fake_data.
    y_pred = odeint(dy_dt, data[0], ts, A, rtol=1e-6, atol=1e-5, mxstep=1000)
    sigma = numpyro.sample("sigma", dist.LogNormal(-1, 1).expand([2]))
    numpyro.sample("y_pred", dist.Normal(y_pred[1:], sigma), obs=data[1:])

def get_MAP(data, ts, ind_mat):
    optimizer = numpyro.optim.Minimize(method="BFGS") 
    guide = AutoLaplaceApproximation(model)
    svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
    init_state = svi.init(random.PRNGKey(0), data, ts, ind_mat)
    optimal_state, loss = svi.update(init_state, data, ts, ind_mat)
    params = svi.get_params(optimal_state) 
    return params, guide, svi, optimal_state, loss 

ts = jnp.array(list(np.linspace(0, 240, 81)))
y_init = jnp.array([-0.06, -0.01])
K = jnp.array([[ 0., 0.01],
               [ 0.01,  0.]]) 
ind_mat = (K != 0)
fake_data = odeint(dy_dt, y_init, ts, K) 

params, guide, svi, optimal_state, loss = get_MAP(fake_data, ts, ind_mat)
print(params)
print(guide.quantiles(params, 0.91))
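
A side note for anyone running this example on a recent JAX release: to my knowledge, jax.ops.index_update has since been removed in favor of the .at[...].set(...) syntax on arrays. Below is a minimal sketch of the equivalent masked update. The theta values are made up for illustration, and the mask is built with NumPy here (an assumption, not what the post does) so that it stays a concrete, non-traced array.

```python
import numpy as np
import jax.numpy as jnp

# Same sparsity pattern as K in the example above.
K = np.array([[0.0, 0.01],
              [0.01, 0.0]])
ind_mat = K != 0  # concrete boolean mask (NumPy array)

# Hypothetical parameter values, standing in for the sampled theta.
theta = jnp.array([0.5, -0.5])

# Functional update: a new array is returned, nothing is mutated in place.
A = jnp.zeros(ind_mat.shape).at[ind_mat].set(theta)
print(A)
```

The entries of theta are written to the True positions of the mask in row-major order, which matches how the model fills A from ind_mat.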

infinite loop issue with NUTS was resolved by turning progress_bar=True

This is likely an XLA issue; XLA is sometimes very slow on some systems.

The last line should cause the error.

This seems to be fixed in this PR. You can test it with

!pip install -q git+https://github.com/pyro-ppl/numpyro.git@refs/pull/1196/head
guide = AutoLaplaceApproximation(model,
    get_precision=lambda f: jax.jacobian(jax.jacobian(f)))
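
For context on why the nested jax.jacobian may help: as far as I know, jax.jacobian is an alias for reverse-mode jax.jacrev, whereas jax.hessian composes forward mode over reverse mode. The odeint in jax.experimental.ode appears to define only a reverse-mode rule, which would explain the forward-mode error. Here is a small sketch on a toy function (f is hypothetical, not the model's loss) showing that both routes give the same Hessian when forward mode is available:

```python
import jax
import jax.numpy as jnp

def f(x):
    # Toy scalar function; its Hessian is diag(6 * x) with 1 off the diagonal.
    return jnp.sum(x ** 3) + x[0] * x[1]

x = jnp.array([1.0, 2.0])

# Reverse-over-reverse: only reverse-mode rules are needed.
hess_rev = jax.jacobian(jax.jacobian(f))(x)

# Forward-over-reverse: what jax.hessian does; needs forward-mode rules too.
hess_fwd = jax.hessian(f)(x)

print(hess_rev)
print(hess_fwd)
```

Reverse-over-reverse is usually the more expensive of the two, so it makes sense as a workaround rather than a default.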

Thanks! Nice that it was already fixed.

Sorry for this extremely noob question, but: the first line is meant to pull the fixed numpyro version, right? So I run that in the terminal, and then in my model, when I assign the guide, I add the get_precision setting?

!pip install -q git+https://github.com/pyro-ppl/numpyro.git@refs/pull/1196/head

I’m asking because when I run this line in the terminal, it installs a few requirements but still does not let me use “get_precision”.

It tells me:
TypeError: __init__() got an unexpected keyword argument 'get_precision'

That’s right. I tested it with Colab, but I'm not sure whether the installation command works in a terminal. You might want to uninstall numpyro first and remove the -q option to get more info messages.
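
One way to check whether the fixed version is the one Python actually imports is to look at the constructor's signature. A minimal sketch, assuming nothing about numpyro itself: accepts_keyword is a hypothetical helper, and OldGuide and NewGuide are dummy stand-ins for the guide class before and after the PR.

```python
import inspect

def accepts_keyword(cls, name):
    """Return True if cls.__init__ can take `name` as a keyword argument."""
    params = inspect.signature(cls.__init__).parameters
    keyword_kinds = (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )
    # A **kwargs parameter accepts any keyword.
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return True
    return name in params and params[name].kind in keyword_kinds

# Dummy stand-ins (hypothetical), not the real numpyro classes.
class OldGuide:
    def __init__(self, model):
        self.model = model

class NewGuide:
    def __init__(self, model, get_precision=None):
        self.model = model
        self.get_precision = get_precision

old_ok = accepts_keyword(OldGuide, "get_precision")
new_ok = accepts_keyword(NewGuide, "get_precision")
print(old_ok, new_ok)
```

With numpyro installed, the same check would be accepts_keyword(AutoLaplaceApproximation, "get_precision"), and printing numpyro.__version__ and numpyro.__file__ shows which installation is being picked up.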

Thanks @fehiepsi, I also got it to work with Colab! 🙂 Excellent.

I guess even this long thread can come to an end 😛. Btw, did you notice anything odd about my code? I haven’t been able to recover the parameters that I created the toy data with (also not if I do repeated updates, or use svi.run() instead).

I think you can use svi.run. Your code looks great to me. FYI, when the PR (which is under review) is merged, the get_precision name/pattern may change, so please follow the description in the docs.


Thanks a lot! Very helpful