I am using jax within numpyro and want to find an efficient way (maybe using lax.scan or vmap) to update t1 with the values of weights_w1 according to the mask I have. The only code that works uses a for loop, which takes a long time.
Problem formulation
Explicitly, I have
- a vector of weights
weights_w1 = jax.numpy.asarray([0.55312582, 0.20736367, -0.74858772, 0.82512423])
- which I want to map onto a matrix t1, which I have initialized with zeros
t1 = jax.numpy.asarray([[0., 0.],
                        [0., 0.],
                        [0., 0.],
                        [0., 0.]])
- based on this mask, where
mask = jax.numpy.asarray([[1., 0.],
                          [1., 0.],
                          [0., 1.],
                          [0., 1.]])
- With the ideal output for t1 being:
>> print(t1)
[[ 0.5531258   0.        ]
 [ 0.20736367  0.        ]
 [ 0.         -0.7485877 ]
 [ 0.          0.8251242 ]]
Working code with for loop
I can do this in a for loop (which takes a long time).
import jax
import numpy as np
import jax.numpy as jnp

def sparsify_weights(args):
    # Write w_init[i] into position (i, j) of a_full and return the new array.
    pair_indices, a_full, w_init = args
    i, j = pair_indices
    # Note: jax.ops.index_update was deprecated and later removed in newer JAX
    # releases; the equivalent there is a_full.at[i, j].set(w_init[i]).
    a_full = jax.ops.index_update(a_full, jnp.index_exp[i, j], w_init[i])
    return a_full

def prep_sparsify_weights(o_full, w_init, mask, num_parallel):
    # Collect the (row, column) positions of the nonzero mask entries.
    index_pairs = list(zip(*np.nonzero(mask)))
    # Update one element per iteration (this is the slow part).
    for p in index_pairs:
        o_full = sparsify_weights([p, o_full, w_init])
    return o_full

t1 = jnp.zeros((4, 2))
weights_w1 = jax.numpy.asarray([0.55312582, 0.20736367, -0.74858772, 0.82512423])
mask = jax.numpy.asarray([[1., 0.],
                          [1., 0.],
                          [0., 1.],
                          [0., 1.]])
t1 = prep_sparsify_weights(t1, weights_w1, mask, 1)
print(t1)
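For comparison, the element-by-element loop above could probably be collapsed into a single indexed update. The following is only a minimal sketch, assuming a JAX version that provides the `.at[...]` indexed-update syntax and assuming the mask is a fixed, concrete array (so its nonzero positions can be computed once with NumPy). The names `t1_scatter` and `t1_where` are hypothetical and used only for illustration.

```python
import numpy as np
import jax.numpy as jnp

weights_w1 = jnp.asarray([0.55312582, 0.20736367, -0.74858772, 0.82512423])
mask = jnp.asarray([[1., 0.],
                    [1., 0.],
                    [0., 1.],
                    [0., 1.]])
t1 = jnp.zeros((4, 2))

# Row and column positions of the nonzero mask entries, computed once on the host.
rows, cols = np.nonzero(np.asarray(mask))

# One scatter instead of one update per loop iteration:
# position (rows[k], cols[k]) receives weights_w1[rows[k]].
t1_scatter = t1.at[rows, cols].set(weights_w1[rows])

# Alternative without index arrays: broadcast each row's weight across the
# columns and keep it only where the mask is nonzero.
t1_where = jnp.where(mask != 0, weights_w1[:, None], 0.0)

print(t1_scatter)
```

Both variants write `weights_w1[i]` into every masked position of row `i`, which matches the loop's behavior for the example mask. The `jnp.where` form avoids computing indices at all, so it may be the easier one to use inside jitted code.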
Code attempt with lax.scan:
I have not gotten my lax.scan code to work. What are your thoughts?
import jax, sys
import numpy as np
from functools import partial
from jax import jit
import jax.numpy as jnp
from jax import lax

def paraUpdate(pair_indices, a_full, w_init):
    i, j = pair_indices
    a_full = jax.ops.index_update(a_full, jnp.index_exp[i, j], w_init[i])
    return(a_full)

#@jax.jit
#@partial(jit, static_argnums=1)
def filter_jax2(o_full, w_init, jax_ranger):
    return(lax.scan(paraUpdate, jax_ranger, o_full, w_init))

t1 = jnp.zeros((4, 2))
weights_w1 = jax.numpy.asarray([0.55312582, 0.20736367, -0.74858772, 0.82512423])
mask = jax.numpy.asarray([[1., 0.],
                          [1., 0.],
                          [0., 1.],
                          [0., 1.]])
ranger = []
[ranger.append((i, j)) for i,j in zip(np.nonzero(mask)[0], np.nonzero(mask)[1])]
t1 = filter_jax2(t1, weights_w1, jnp.asarray(ranger))
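One likely reason the attempt above fails is the calling convention: `lax.scan(f, init, xs)` expects `f` to take `(carry, x)` and return a `(new_carry, y)` pair, with the carried state passed as `init` and the per-step inputs as `xs`. Below is a minimal sketch of how the same update might be arranged to fit that shape. It assumes a JAX version with the `.at[...]` update syntax, and the names `fill_with_scan` and `step` are hypothetical, not part of the original code.

```python
import numpy as np
import jax.numpy as jnp
from jax import lax

def fill_with_scan(o_full, w_init, index_pairs):
    # step follows the (carry, x) -> (new_carry, y) convention of lax.scan:
    # the matrix is the carry, one (row, column) pair is the per-step input.
    def step(a_full, pair):
        i, j = pair[0], pair[1]
        a_full = a_full.at[i, j].set(w_init[i])
        return a_full, None  # nothing to collect per step

    final, _ = lax.scan(step, o_full, index_pairs)
    return final

t1 = jnp.zeros((4, 2))
weights_w1 = jnp.asarray([0.55312582, 0.20736367, -0.74858772, 0.82512423])
mask = jnp.asarray([[1., 0.],
                    [1., 0.],
                    [0., 1.],
                    [0., 1.]])

# (row, column) pairs of the nonzero mask entries, shape (4, 2).
rows, cols = np.nonzero(np.asarray(mask))
index_pairs = jnp.asarray(np.stack([rows, cols], axis=1))

t1 = fill_with_scan(t1, weights_w1, index_pairs)
print(t1)
```

Here `w_init` is captured by the inner function rather than passed through `lax.scan`, since it does not change between steps. Note that a scan still performs the updates sequentially, so for a plain masked assignment like this one a single vectorized operation is probably the faster choice.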