I am using `jax` within `numpyro` and want to find an efficient way (maybe using `lax.scan` or `vmap`) to update `t1` with the values of `weights_w1` with respect to the `mask`. The only code that works uses a for loop, which takes a *long time*.

## Problem formulation

Explicitly, I have

- a vector of weights

```
weights_w1 = jax.numpy.asarray([0.55312582, 0.20736367, -0.74858772, 0.82512423])
```

- which I want to map onto a matrix `t1` that I have initialized with zeros

```
t1 = jax.numpy.asarray([[0., 0.],
                        [0., 0.],
                        [0., 0.],
                        [0., 0.]])
```

- based on this `mask`

```
mask = jax.numpy.asarray([[1., 0.],
                          [1., 0.],
                          [0., 1.],
                          [0., 1.]])
```

- with the ideal output for `t1` being:

```
>>> print(t1)
[[ 0.5531258   0.        ]
 [ 0.20736367  0.        ]
 [ 0.         -0.7485877 ]
 [ 0.          0.8251242 ]]
```
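In other words, row `i` of the mask says in which columns `weights_w1[i]` should appear. A minimal sketch of that rule with `vmap`, assuming (as in the question) that `t1` starts as all zeros so it can be built directly rather than updated; `fill_row` and `sparsify` are hypothetical names:

```python
import jax
import jax.numpy as jnp

weights_w1 = jnp.asarray([0.55312582, 0.20736367, -0.74858772, 0.82512423])
mask = jnp.asarray([[1., 0.],
                   [1., 0.],
                   [0., 1.],
                   [0., 1.]])

def fill_row(mask_row, w):
    # Hypothetical helper: put this row's single weight w wherever the mask row is nonzero
    return jnp.where(mask_row != 0, w, 0.0)

# vmap maps fill_row over the leading axis of both inputs,
# so row i of mask is paired with the scalar weights_w1[i]
sparsify = jax.jit(jax.vmap(fill_row))

t1 = sparsify(mask, weights_w1)
print(t1)
```

Because the value written at `(i, j)` depends only on the row index `i`, plain broadcasting gives the same result without `vmap`: `jnp.where(mask != 0, weights_w1[:, None], t1)`, which also keeps any existing entries of `t1` outside the mask.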

## Working code with for loop

I can do this in a for loop (which takes a *long time*).

```
import jax
import numpy as np
import jax.numpy as jnp

def sparsify_weights(args):
    pair_indices, a_full, w_init = args
    i, j = pair_indices
    # Functional update: returns a copy of a_full with entry (i, j) set to w_init[i].
    # Older JAX releases spelled this jax.ops.index_update(a_full, jnp.index_exp[i, j], w_init[i]),
    # which was later removed in favor of the .at[] syntax.
    a_full = a_full.at[i, j].set(w_init[i])
    return a_full

def prep_sparsify_weights(o_full, w_init, mask, num_parallel):
    # (row, column) positions of the nonzero mask entries
    index_pairs = list(zip(*np.nonzero(mask)))
    # One separate update (and device dispatch) per pair: this Python loop is the slow part
    for p in index_pairs:
        o_full = sparsify_weights([p, o_full, w_init])
    return o_full

t1 = jnp.zeros((4, 2))
weights_w1 = jnp.asarray([0.55312582, 0.20736367, -0.74858772, 0.82512423])
mask = jnp.asarray([[1., 0.],
                   [1., 0.],
                   [0., 1.],
                   [0., 1.]])
t1 = prep_sparsify_weights(t1, weights_w1, mask, 1)
print(t1)
```
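The loop above visits the pairs one at a time, but `.at[...].set(...)` also accepts index arrays, so all pairs can be written in a single scatter. A minimal sketch, assuming the nonzero positions are computed once on the host outside of `jit` (plain `np.nonzero` cannot run on a traced array, and `jnp.nonzero` only works under `jit` when given a static `size`); `sparsify_weights_batched` is a hypothetical name:

```python
import numpy as np
import jax
import jax.numpy as jnp

weights_w1 = jnp.asarray([0.55312582, 0.20736367, -0.74858772, 0.82512423])
mask = jnp.asarray([[1., 0.],
                   [1., 0.],
                   [0., 1.],
                   [0., 1.]])

# Host-side, outside jit: rows[k], cols[k] is the k-th (i, j) pair the for loop visits
rows, cols = np.nonzero(np.asarray(mask))

@jax.jit
def sparsify_weights_batched(o_full, w_init, rows, cols):
    # Hypothetical helper: one scatter for all pairs, o_full[rows[k], cols[k]] = w_init[rows[k]]
    return o_full.at[rows, cols].set(w_init[rows])

t1 = sparsify_weights_batched(jnp.zeros((4, 2)), weights_w1, rows, cols)
print(t1)
```

If the same `mask` is reused across many calls (for example inside a `numpyro` model), `rows` and `cols` only need to be computed once.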

## Code attempt with `lax.scan`

I have not gotten my `lax.scan` code to work. What are your thoughts?

```
import jax, sys
import numpy as np
from functools import partial
from jax import jit
import jax.numpy as jnp
from jax import lax

def paraUpdate(pair_indices, a_full, w_init):
    i, j = pair_indices
    a_full = jax.ops.index_update(a_full, jnp.index_exp[i, j], w_init[i])
    return(a_full)

#@jax.jit
#@partial(jit, static_argnums=1)
def filter_jax2(o_full, w_init, jax_ranger):
    return(lax.scan(paraUpdate, jax_ranger, o_full, w_init))

t1 = jnp.zeros((4, 2))
weights_w1 = jax.numpy.asarray([0.55312582, 0.20736367, -0.74858772, 0.82512423])
mask = jax.numpy.asarray([[1., 0.],
                          [1., 0.],
                          [0., 1.],
                          [0., 1.]])
ranger = []
[ranger.append((i, j)) for i,j in zip(np.nonzero(mask)[0], np.nonzero(mask)[1])]
t1 = filter_jax2(t1, weights_w1, jnp.asarray(ranger))
```
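Two things stand out in the attempt. `lax.scan` is called as `lax.scan(f, init, xs)`, with `length` as the fourth positional argument, so here the index pairs become the carry, `t1` becomes the sequence being scanned, and `w_init` lands in the `length` slot. Also, `f` must have the form `f(carry, x) -> (new_carry, y)`, whereas `paraUpdate` takes three arguments and returns only the array. A minimal sketch of a working scan, assuming the pairs are precomputed on the host like `ranger`; `filter_scan` and `body` are hypothetical names:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax import lax

weights_w1 = jnp.asarray([0.55312582, 0.20736367, -0.74858772, 0.82512423])
mask = jnp.asarray([[1., 0.],
                   [1., 0.],
                   [0., 1.],
                   [0., 1.]])

# (num_pairs, 2) integer array of (i, j) positions, playing the role of `ranger`
index_pairs = jnp.asarray(np.argwhere(np.asarray(mask)))

@jax.jit
def filter_scan(o_full, w_init, index_pairs):
    def body(carry, pair):
        # scan body signature: (carry, x) -> (new_carry, per-step output)
        i, j = pair[0], pair[1]
        carry = carry.at[i, j].set(w_init[i])
        return carry, None  # no per-step output is needed
    # init = the matrix being updated; xs = the index pairs, scanned along axis 0
    o_final, _ = lax.scan(body, o_full, index_pairs)
    return o_final

t1 = filter_scan(jnp.zeros((4, 2)), weights_w1, index_pairs)
print(t1)
```

`lax.scan` still applies the updates one after another, just inside a single compiled loop instead of a Python loop, so for this particular mapping a single vectorized scatter or a `vmap`/broadcasting formulation is likely to be faster.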