Performance of repeated matrix multiplication

Hi all!

I’m building a pharmacokinetic model that’s basically a discretized first-order ODE where c_t = K c_{t-1}.
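In code, the recurrence is roughly the following (illustrative names, not the actual model):

```python
import jax.numpy as jnp
from jax import lax

def propagate(K, c0, T):
    """Apply c_t = K @ c_{t-1} for T steps and return the whole trajectory."""
    def step(c, _):
        c = K @ c
        return c, c  # (new carry, value to stack)
    _, traj = lax.scan(step, c0, None, length=T)
    return traj  # shape (T, len(c0))
```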

It works great, but I can’t seem to get the program to 1) meaningfully use the GPU or 2) use more than ~1.5 CPUs. I have to assume the issue is some kind of blocking/non-parallelizable part, but I can’t figure out what that would be.

A stripped-down version of the model is here. (Inputs: Y.npy, M.npy.)

I have already tried:

  • Making it explicitly an ODE. (This made it way slower.)
  • Using TensorBoard to figure out what was slow, but it was super vague: it just said the program was sleeping or blocking on some kind of I/O. (I unfortunately lost those results.)
  • Getting rid of some of the data I’m storing (like the …)

Does anyone have any recommendations?

Cheers!

how long is this for loop?

for i in range(len(k01))

Thanks for the careful review!

It’s the same as N. I think I left it like that because I wasn’t sure how to get plate to handle it correctly.

At the moment it’s no more than 10, but I’d eventually like it to be hundreds.

Update:

I replaced:

    K = numpyro.deterministic('K', jnp.stack([
        jnp.asarray(
            [[  1-k01[i], k01[i]*(Vd[0,i]/Vd[1,i]),                        0],
             [         0,            1-k12[i]-u[i], k12[i]*(Vd[1,i]/Vd[2,i])],
             [         0, k21[i]*(Vd[2,i]/Vd[1,i]),                 1-k21[i]]])
        for i in range(len(k01))
    ]))

with

    zeros = jnp.zeros((N,))
    K = jnp.swapaxes(jnp.asarray(
            [[  1-k01, k01*(Vd[0]/Vd[1]),             zeros],
             [  zeros,           1-k12-u, k12*(Vd[1]/Vd[2])],
             [  zeros, k21*(Vd[2]/Vd[1]),              1-k21]]
    ).T, 1, 2)

without much performance change (with N=4, 1:04 vs 1:06 runtime and, with N=8, 6:08 vs 6:30).

yeah that change probably won’t matter much until the loop gets longer.

are you using enable_x64? that might conceivably help.
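e.g. (conventionally this goes at startup, before any arrays are created):

```python
import jax

# switch jax (and therefore numpyro) to double precision;
# arrays created before this call keep their 32-bit dtype
jax.config.update("jax_enable_x64", True)
```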

not sure how to make your code faster. presumably the scan takes up most of your compute. if you’re going to optimize anything that would probably be the place to look

yeah that change probably won’t matter much until the loop gets longer.

Word. Well, it did give some improvement, and the loop will definitely be getting longer, so I appreciate the advice nonetheless!

are you using enable_x64? that might conceivably help.

So, I’m actually seeing a pretty substantial performance hit when I do this. Both versions (with and without the vectorized construction of K) go from runtimes of ~6 minutes to ~11-14 minutes (for N=10).

presumably the scan takes up most of your compute. if you’re going to optimize anything that would probably be the place to look

Yeah, I hear that.

So it’s not so much that the scan is kind of slow; what’s really bothering me is that this code parallelizes really poorly. When I run it with N = 10, CPU usage tops out at ~170% on a 12-core machine. (Predictably, when I try to put it on a GPU, nvidia-smi reports <10% usage.)

If it’s true that scan is where all the compute is, I’d like to parallelize that across each n. I thought the best way would be with the einsum (since the einsum should see that the first index is independent and vectorize smartly?) but maybe I should be trying to vmap the scan call across Ns?
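For concreteness, the vmap-the-scan option I have in mind looks something like this (made-up names and shapes):

```python
import jax
import jax.numpy as jnp
from jax import lax

def one_series(K, c0, T):
    """Roll c_t = K @ c_{t-1} forward for a single n."""
    def step(c, _):
        c = K @ c
        return c, c
    _, traj = lax.scan(step, c0, None, length=T)
    return traj  # (T, 3)

# batch the whole scan over the leading N axis of K (N, 3, 3) and c0 (N, 3);
# T is unmapped, so it stays a concrete Python int inside the scan
propagate_all = jax.vmap(one_series, in_axes=(0, 0, None))
```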

At kind of a meta level, do you have any advice for what to try or how to investigate more deeply? I’ve been doing a kind of “coding Monte Carlo”: changing things more or less at random and keeping them if they work better, which is pretty inefficient. :slight_smile:

how the xla compiler works is pretty mysterious to me so it’s hard to give specific advice. in general a lot of numpyro performance questions really boil down to jax performance questions.

didn’t look at your model in detail but maybe you can rewrite things to compute powers of K in large parallelizable tensor ops along the lines of this code. if so you might actually start seeing good GPU utilization and therefore some speed-up.
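one way to do that (a sketch, not tuned for your model) is a parallel prefix over matmuls, e.g. with `lax.associative_scan`:

```python
import jax.numpy as jnp
from jax import lax

def cumulative_powers(K, T):
    """K^1, K^2, ..., K^T as a parallel prefix of matmuls.

    associative_scan needs only O(log T) sequential rounds, and each
    round is one big batched matmul, which is much friendlier to a GPU
    than T tiny sequential matmuls.
    """
    Ks = jnp.broadcast_to(K, (T,) + K.shape)
    # jnp.matmul batches over the leading axis that associative_scan adds
    return lax.associative_scan(jnp.matmul, Ks)
```

then something like `jnp.einsum('tij,j->ti', cumulative_powers(K, T), c0)` should reproduce the scanned trajectory.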

So, I discovered something interesting, which is that replacing

    K = jnp.swapaxes(jnp.asarray(
            [[  1-k01, k01*(Vd[0]/Vd[1]),             zeros],
             [  zeros,           1-k12-u, k12*(Vd[1]/Vd[2])],
             [  zeros, k21*(Vd[2]/Vd[1]),              1-k21]]
    ).T, 1, 2)

with

    K = jnp.swapaxes(jnp.asarray(
            [[ 1-k01,     k01, zeros],
             [ zeros, 1-k12-u,   k12],
             [ zeros,     k21, 1-k21]]
    ).T, 1, 2)

and then dividing the Vd parameters out at the end (i.e. after the scan invocation) gives a decent improvement in parallelization and performance. (Because all the rates are first-order, it doesn’t matter what the concentration is, just the mass in each compartment.)
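To spell out why this is exact rather than an approximation: with D = diag(Vd), the full matrix is a similarity transform of the plain-rate one, K = D A D⁻¹, so Kᵗ = D Aᵗ D⁻¹ and the volume ratios can be applied once outside the scan. A minimal sketch (made-up names):

```python
import jax.numpy as jnp
from jax import lax

def propagate_plain(A, Vd, c0, T):
    """Scan with the plain-rate matrix A and rescale once at the end.

    Since K = diag(Vd) @ A @ diag(1/Vd), we have
    K^t @ c0 = diag(Vd) @ (A^t @ (c0 / Vd)),
    so the volume ratios never need to appear inside the scan.
    """
    def step(x, _):
        x = A @ x
        return x, x
    _, xs = lax.scan(step, c0 / Vd, None, length=T)
    return xs * Vd  # Vd broadcasts over the time axis


def propagate_full(K, c0, T):
    """Reference: scan directly with the full (volume-scaled) matrix."""
    def step(c, _):
        c = K @ c
        return c, c
    _, cs = lax.scan(step, c0, None, length=T)
    return cs
```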