Vmap or lax.scan with index_update in JAX

I am using JAX within NumPyro and want to find an efficient way (maybe using lax.scan or vmap) to fill t1 with the values of weights_w1 according to a mask. The only code that works uses a Python for loop, which takes a long time.

Problem formulation

Explicitly, I have

  1. a vector of weights
weights_w1 = jax.numpy.asarray([0.55312582, 0.20736367, -0.74858772, 0.82512423])
  2. which I want to map onto a matrix t1, initialized with zeros
t1 = jax.numpy.asarray([[0., 0.],
                        [0., 0.],
                        [0., 0.],
                        [0., 0.]])

  3. according to this mask:
mask = jax.numpy.asarray([[1., 0.],
                         [1., 0.],
                         [0., 1.],
                         [0., 1.]])

  4. with the ideal output for t1 being:
>>> print(t1)
[[ 0.5531258   0.        ]
 [ 0.20736367  0.        ]
 [ 0.         -0.7485877 ]
 [ 0.          0.8251242 ]]
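Because each row of this mask has exactly one nonzero entry and weight i goes to row i (the same rule the loop below uses via w_init[i]), the problem can also be read as a per-row operation, which is what vmap is for. A minimal sketch under that assumption; fill_row is a hypothetical helper name, and jnp.where keeps the existing t1 value wherever the mask is zero:

```python
import jax
import jax.numpy as jnp

t1 = jnp.zeros((4, 2))
weights_w1 = jnp.asarray([0.55312582, 0.20736367, -0.74858772, 0.82512423])
mask = jnp.asarray([[1., 0.],
                    [1., 0.],
                    [0., 1.],
                    [0., 1.]])

# Hypothetical per-row rule: put the row's weight wherever the mask row is nonzero,
# otherwise keep the current value of t1.
def fill_row(w, mask_row, t1_row):
    return jnp.where(mask_row != 0, w, t1_row)

# vmap maps fill_row over the leading (row) axis of all three arguments.
t1_vmap = jax.vmap(fill_row)(weights_w1, mask, t1)

# Equivalent broadcasting form: the (4, 1) weights broadcast against the (4, 2) mask.
t1_bcast = jnp.where(mask != 0, weights_w1[:, None], t1)

print(t1_vmap)
```

For this shape of problem the broadcasting line and the vmap version do the same work; vmap mostly pays off when the per-row logic is harder to express with broadcasting.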

Working code with for loop

I can do this in a for loop (which takes a long time).

import jax
import numpy as np
import jax.numpy as jnp

def sparsify_weights(args):
    pair_indices, a_full, w_init = args
    i, j = pair_indices
    # Note: jax.ops.index_update was deprecated and later removed from JAX,
    # so this line only runs on older JAX versions.
    a_full = jax.ops.index_update(a_full, jnp.index_exp[i, j], w_init[i])
    return a_full

def prep_sparsify_weights(o_full, w_init, mask, num_parallel):
    # (row, col) pairs of every nonzero entry in the mask
    index_pairs = list(zip(*np.nonzero(mask)))

    # One separate update per pair, dispatched from Python
    for p in index_pairs:
        o_full = sparsify_weights([p, o_full, w_init])
    return o_full

t1 = jnp.zeros((4, 2))
weights_w1 = jax.numpy.asarray([0.55312582, 0.20736367, -0.74858772, 0.82512423])
mask = jax.numpy.asarray([[1., 0.],
                          [1., 0.],
                          [0., 1.],
                          [0., 1.]])

t1 = prep_sparsify_weights(t1, weights_w1, mask, 1)
print(t1)

Code attempt with lax.scan:

I have not gotten my lax.scan code to work. What are your thoughts?

import jax, sys
import numpy as np
from functools import partial
from jax import jit
import jax.numpy as jnp
from jax import lax

def paraUpdate(pair_indices, a_full, w_init):
    i, j = pair_indices
    a_full = jax.ops.index_update(a_full, jnp.index_exp[i, j], w_init[i])
    return(a_full)

#@jax.jit
#@partial(jit, static_argnums=1)
def filter_jax2(o_full, w_init, jax_ranger):
    return(lax.scan(paraUpdate, jax_ranger, o_full, w_init))

t1 = jnp.zeros((4, 2))
weights_w1 = jax.numpy.asarray([0.55312582, 0.20736367, -0.74858772, 0.82512423])
mask = jax.numpy.asarray([[1., 0.],
                          [1., 0.],
                          [0., 1.],
                          [0., 1.]])

ranger = []
[ranger.append((i, j)) for i,j in zip(np.nonzero(mask)[0], np.nonzero(mask)[1])]
    
t1 = filter_jax2(t1, weights_w1, jnp.asarray(ranger))
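The attempt fails mainly because of the calling convention: lax.scan's signature is lax.scan(f, init, xs, length=None, ...), and f must take (carry, x) and return (new_carry, y). In the call above, init receives the index array, xs receives the matrix, and w_init lands in the length parameter; paraUpdate also takes three arguments and returns a single array instead of a pair. A corrected sketch, assuming the matrix is the carry and the index pairs are the scanned-over sequence (fill_with_scan is a hypothetical name, and .at[].set() stands in for jax.ops.index_update, which later JAX releases removed):

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import lax

def fill_with_scan(o_full, w_init, index_pairs):
    # f(carry, x) -> (new_carry, y): the matrix is the carry,
    # each x is one (row, col) pair from index_pairs.
    def step(a_full, pair):
        i, j = pair[0], pair[1]
        a_full = a_full.at[i, j].set(w_init[i])
        return a_full, None  # no per-step output is needed

    # scan iterates over the leading axis of index_pairs, shape (n_pairs, 2)
    final, _ = lax.scan(step, o_full, index_pairs)
    return final

t1 = jnp.zeros((4, 2))
weights_w1 = jnp.asarray([0.55312582, 0.20736367, -0.74858772, 0.82512423])
mask = jnp.asarray([[1., 0.],
                    [1., 0.],
                    [0., 1.],
                    [0., 1.]])

# (row, col) of every nonzero mask entry, one pair per row of the result
ranger = jnp.asarray(np.argwhere(np.asarray(mask)))

t1 = jax.jit(fill_with_scan)(t1, weights_w1, ranger)
print(t1)
```

Note that scan still applies the updates one after another; it removes the Python loop overhead but does not parallelize the writes, so a single vectorized update is usually the better fit here.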

What prevents you from doing a single, large jax.ops.index_update call that sets all the indices at once?

Hm, I’m not sure how exactly to do that. Could you show me using the toy dataset and variable names above?

import jax.numpy as jnp

x = jnp.zeros((4, 2))
y = jnp.array([1.1, 2.2, 3.3, 4.4])
idx = ([0, 1, 2, 3], [0, 0, 1, 1])
x = x.at[idx].set(y)
print(x)

This outputs:

[[1.1 0. ]
 [2.2 0. ]
 [0.  3.3]
 [0.  4.4]]
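Applied to the question's own variable names, the same idea might look like the following sketch. It assumes, as the loop version does, that weight i belongs in row i, and it derives the index arrays from mask with jnp.nonzero instead of writing them by hand:

```python
import jax.numpy as jnp

t1 = jnp.zeros((4, 2))
weights_w1 = jnp.asarray([0.55312582, 0.20736367, -0.74858772, 0.82512423])
mask = jnp.asarray([[1., 0.],
                    [1., 0.],
                    [0., 1.],
                    [0., 1.]])

# Row and column indices of every nonzero mask entry, in row-major order.
rows, cols = jnp.nonzero(mask)

# One scatter writes all values at once; weights_w1[rows] picks each row's weight,
# matching w_init[i] in the loop version.
t1 = t1.at[rows, cols].set(weights_w1[rows])
print(t1)
```

If this runs inside jax.jit, jnp.nonzero needs a static size= argument because its output shape depends on the data; alternatively, compute rows and cols with NumPy outside the jitted function.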

Ah, so this uses the .at[...].set() syntax rather than jax.ops.index_update.

I was not familiar with .set(), and found this GitHub article (dated 10/01/2021) explaining that jax.ops.index_update is no longer the recommended way to update indices in JAX and that .at[...].set() should be used instead.
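For reference, the .at property also replaces the other jax.ops update helpers, such as jax.ops.index_add. A small sketch of the correspondence, assuming a recent JAX version where the old functions no longer exist; it also shows that repeated indices accumulate with .add():

```python
import jax.numpy as jnp

x = jnp.zeros(5)
idx = jnp.array([0, 2, 2])  # index 2 appears twice

# Old: jax.ops.index_update(x, idx, 1.0)
updated = x.at[idx].set(1.0)   # values: [1, 0, 1, 0, 0]

# Old: jax.ops.index_add(x, idx, 1.0); repeated indices accumulate
added = x.at[idx].add(1.0)     # values: [1, 0, 2, 0, 0]

# JAX arrays are immutable: x itself is unchanged, each call returns a new array.
print(x, updated, added)
```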

If you have other resources, I would be glad to hear about them; I'd like to stay up to date with current best practices like the one you shared!