Vmap or lax.scan with index_update in JAX

I am using `jax` within `numpyro` and want an efficient way (maybe using `lax.scan` or `vmap`) to fill `t1` with the values of `weights_w1` according to a `mask`. The only code I have that works uses a for loop, which takes a long time.

Problem formulation

Concretely, I have:

1. a vector of weights
``````
weights_w1 = jax.numpy.asarray([0.55312582, 0.20736367, -0.74858772, 0.82512423])
``````
1. which I want to map onto a matrix `t1`, initialized with zeros,
``````
t1 = jax.numpy.asarray([[0., 0.],
                        [0., 0.],
                        [0., 0.],
                        [0., 0.]])
``````
1. based on this `mask`:
``````
mask = jax.numpy.asarray([[1., 0.],
                          [1., 0.],
                          [0., 1.],
                          [0., 1.]])
``````
1. so that the ideal output for `t1` is:
``````
>>> print(t1)
[[ 0.5531258   0.        ]
 [ 0.20736367  0.        ]
 [ 0.         -0.7485877 ]
 [ 0.          0.8251242 ]]
``````
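Stated as a rule, the target is: wherever `mask[i, j]` is 1, `t1[i, j]` becomes `weights_w1[i]`, and everything else stays 0. Under that reading (an interpretation of the example, not something stated in the post), no loop is needed at all: broadcasting `weights_w1` as a column against `mask` with `jnp.where` produces the whole matrix in one vectorized operation. A minimal sketch using the toy arrays above:

```python
import jax.numpy as jnp

weights_w1 = jnp.asarray([0.55312582, 0.20736367, -0.74858772, 0.82512423])
mask = jnp.asarray([[1., 0.],
                    [1., 0.],
                    [0., 1.],
                    [0., 1.]])

# weights_w1[:, None] has shape (4, 1), so it broadcasts across the two
# columns of mask: entry (i, j) picks weights_w1[i] where mask[i, j] == 1.
t1 = jnp.where(mask == 1, weights_w1[:, None], 0.0)
print(t1)
```

Because this is a single elementwise op, it also works unchanged inside `jax.jit` or a `numpyro` model.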

Working code with a for loop

I can do this in a for loop (which takes a long time).

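``````
import jax
import numpy as np
import jax.numpy as jnp

def sparsify_weights(args):
    # Write w_init[i] into position (i, j) of a_full.
    # Note: jax.ops.index_update has since been removed from newer JAX releases;
    # the modern equivalent is a_full.at[i, j].set(w_init[i]).
    pair_indices, a_full, w_init = args
    i, j = pair_indices
    a_full = jax.ops.index_update(a_full, jnp.index_exp[i, j], w_init[i])
    return a_full

def prep_sparsify_weights(o_full, w_init, mask, target):
    # Reconstructed: the index pairs are assumed to be the (row, col)
    # positions where mask equals `target`.
    index_pairs = np.argwhere(np.asarray(mask) == target)
    # One separate array update per pair, which is what makes this slow.
    for p in index_pairs:
        o_full = sparsify_weights([p, o_full, w_init])
    return o_full

t1 = jnp.zeros((4, 2))
weights_w1 = jax.numpy.asarray([0.55312582, 0.20736367, -0.74858772, 0.82512423])
mask = jax.numpy.asarray([[1., 0.],
                          [1., 0.],
                          [0., 1.],
                          [0., 1.]])

t1 = prep_sparsify_weights(t1, weights_w1, mask, 1)
print(t1)
``````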
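Since every nonzero entry in row `i` of `mask` receives the same value `weights_w1[i]`, the loop over index pairs can also be replaced by mapping a per-row function over the rows with `jax.vmap`, which is one way to use `vmap` as the title asks. A minimal sketch, assuming the toy arrays from the question; `fill_row` and `sparsify_rows` are hypothetical helper names:

```python
import jax
import jax.numpy as jnp

weights_w1 = jnp.asarray([0.55312582, 0.20736367, -0.74858772, 0.82512423])
mask = jnp.asarray([[1., 0.],
                    [1., 0.],
                    [0., 1.],
                    [0., 1.]])

def fill_row(mask_row, w):
    """Return one row of t1: w where mask_row is 1, else 0."""
    return jnp.where(mask_row == 1, w, 0.0)

# vmap maps fill_row over axis 0 of both inputs, i.e. over the pairs
# (mask[i], weights_w1[i]), replacing the Python loop with one batched call.
sparsify_rows = jax.jit(jax.vmap(fill_row))
t1 = sparsify_rows(mask, weights_w1)
print(t1)
```

Here `vmap` is mostly a way of writing the broadcast by hand; for this particular problem it compiles to essentially the same batched `where`.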

Code attempt with `lax.scan`:

I have not gotten my `lax.scan` code to work. What are your thoughts?

``````
import jax, sys
import numpy as np
from functools import partial
from jax import jit
import jax.numpy as jnp
from jax import lax

def paraUpdate(pair_indices, a_full, w_init):
    i, j = pair_indices
    a_full = jax.ops.index_update(a_full, jnp.index_exp[i, j], w_init[i])
    return(a_full)

#@jax.jit
#@partial(jit, static_argnums=1)
def filter_jax2(o_full, w_init, jax_ranger):
    return(lax.scan(paraUpdate, jax_ranger, o_full, w_init))

t1 = jnp.zeros((4, 2))
weights_w1 = jax.numpy.asarray([0.55312582, 0.20736367, -0.74858772, 0.82512423])
mask = jax.numpy.asarray([[1., 0.],
                          [1., 0.],
                          [0., 1.],
                          [0., 1.]])

ranger = []

t1 = filter_jax2(t1, weights_w1, jnp.asarray(ranger))
``````
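For reference, `lax.scan(f, init, xs)` calls `f(carry, x)` once per slice of `xs` along its leading axis, and `f` must return a `(new_carry, y)` pair. The attempt above passes the index pairs as `init`, `o_full` as `xs`, and `w_init` in the fourth positional slot (which is `length`), and `paraUpdate` takes three arguments instead of `(carry, x)`. A minimal sketch of a working scan, assuming `ranger` is meant to hold the (row, column) positions of the 1s in `mask` (the pairs below are read off the toy `mask`, not taken from the post):

```python
import jax
import jax.numpy as jnp
from jax import lax

def paraUpdate(a_full, pair_indices, w_init):
    """One scan step: write w_init[i] into a_full[i, j]; no per-step output."""
    i, j = pair_indices[0], pair_indices[1]
    return a_full.at[i, j].set(w_init[i]), None

@jax.jit
def filter_jax2(o_full, w_init, jax_ranger):
    # The carry is the matrix being filled; scan slices jax_ranger along
    # axis 0, so each step receives one (i, j) pair. w_init is closed over.
    step = lambda a_full, pair: paraUpdate(a_full, pair, w_init)
    filled, _ = lax.scan(step, o_full, jax_ranger)
    return filled

t1 = jnp.zeros((4, 2))
weights_w1 = jnp.asarray([0.55312582, 0.20736367, -0.74858772, 0.82512423])
# Assumed contents of ranger: the (row, col) positions of the 1s in mask.
ranger = jnp.asarray([[0, 0], [1, 0], [2, 1], [3, 1]])

t1 = filter_jax2(t1, weights_w1, ranger)
print(t1)
```

This compiles to a single loop on the device, but it still performs one update per step, so a single vectorized `.at[...].set()` over all index pairs is likely to be faster for this problem.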

What prevents you from doing a single large `jax.ops.index_update` call?

Hm, I’m not sure exactly how to do that. Could you show me using the toy data and variable names above?

``````
import jax.numpy as jnp

x = jnp.zeros((4, 2))
y = jnp.array([1.1, 2.2, 3.3, 4.4])
idx = ([0, 1, 2, 3], [0, 0, 1, 1])
x = x.at[idx].set(y)
print(x)
``````

This outputs:

``````
[[1.1 0. ]
 [2.2 0. ]
 [0.  3.3]
 [0.  4.4]]
``````
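The index tuple `idx` above is written out by hand. A sketch that derives it from `mask` instead, assuming each row of `mask` has exactly one nonzero entry as in the toy example: `jnp.nonzero` returns the row and column indices in row-major order, and its `size` argument is needed under `jax.jit` because the number of nonzeros must be known at trace time. `sparsify` is a hypothetical name.

```python
import jax
import jax.numpy as jnp

t1 = jnp.zeros((4, 2))
weights_w1 = jnp.asarray([0.55312582, 0.20736367, -0.74858772, 0.82512423])
mask = jnp.asarray([[1., 0.],
                    [1., 0.],
                    [0., 1.],
                    [0., 1.]])

@jax.jit
def sparsify(t, w, m):
    # size must be static under jit; m.shape[0] is known at trace time.
    # Assumes one nonzero per row, so the nonzero count equals the row count.
    rows, cols = jnp.nonzero(m, size=m.shape[0])
    # A single scatter: t[rows[k], cols[k]] = w[rows[k]] for every k.
    return t.at[rows, cols].set(w[rows])

t1 = sparsify(t1, weights_w1, mask)
print(t1)
```

If a row could hold more than one nonzero, `size` would have to be an upper bound on the total count; the padded slots are filled with index 0 by default and would then write into `t[0, 0]`, so that case needs extra care.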

Ah, so this uses `.at[...].set()` rather than `jax.ops.index_update`.

I was not familiar with `set`, and found this GitHub article noting (as of 10/01/2021) that `jax.ops.index_update` is no longer the recommended way to update array entries in `jax`, and that `.at[...].set()` should be used instead.
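To make the correspondence concrete: `x.at[idx].set(y)` plays the role of `jax.ops.index_update(x, idx, y)`, and the other old helpers map onto sibling methods, for example `jax.ops.index_add` onto `.at[idx].add(y)`. Like the old functions, these return a new array and leave `x` unchanged, since JAX arrays are immutable. A small sketch:

```python
import jax.numpy as jnp

x = jnp.zeros(3)

# Functional update: y is a new array, x is not modified.
y = x.at[1].set(5.0)

# .add accumulates, so a repeated index is summed rather than overwritten.
z = x.at[jnp.array([0, 0, 2])].add(1.0)

print(x)  # [0. 0. 0.]
print(y)  # [0. 5. 0.]
print(z)  # [2. 0. 1.]
```

The accumulation behavior is one reason to prefer `.add` over `.set` when index lists may contain duplicates; with `.set`, which of several writes to the same index wins is not something to rely on.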

If you have other resources, I would be glad to keep up with the latest best practices like the one you shared!