I am using JAX within NumPyro and want to find an efficient way (maybe using `lax.scan` or `vmap`) to fill `t1` with the values of `weights_w1` according to a mask. The only code I have that works uses a Python for loop, which takes a long time.
Problem formulation
Explicitly, I have
- a vector of weights,

  ```python
  weights_w1 = jax.numpy.asarray([0.55312582, 0.20736367, -0.74858772, 0.82512423])
  ```

- which I want to map onto a matrix `t1`, initialized with zeros,

  ```python
  t1 = jax.numpy.asarray([[0., 0.],
                          [0., 0.],
                          [0., 0.],
                          [0., 0.]])
  ```

- based on this `mask`, where a nonzero entry in row `i` marks the column that receives `weights_w1[i]`,

  ```python
  mask = jax.numpy.asarray([[1., 0.],
                            [1., 0.],
                            [0., 1.],
                            [0., 1.]])
  ```

- with the ideal output for `t1` being:

  ```
  >>> print(t1)
  [[ 0.5531258   0.        ]
   [ 0.20736367  0.        ]
   [ 0.         -0.7485877 ]
   [ 0.          0.8251242 ]]
  ```
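Stated elementwise, the target is `t1[i, j] = weights_w1[i]` wherever `mask[i, j]` is nonzero, and `t1[i, j]` unchanged elsewhere. A minimal sketch, assuming that rule is the whole requirement, expresses it without any per-entry updates: once with broadcasting and once with `jax.vmap` over rows. `fill_by_broadcast` and `fill_row` are hypothetical helper names, not part of JAX or NumPyro.

```python
import jax
import jax.numpy as jnp

weights_w1 = jnp.asarray([0.55312582, 0.20736367, -0.74858772, 0.82512423])
mask = jnp.asarray([[1., 0.],
                    [1., 0.],
                    [0., 1.],
                    [0., 1.]])
t1 = jnp.zeros((4, 2))

@jax.jit
def fill_by_broadcast(t, w, m):
    # w[:, None] has shape (4, 1) and broadcasts across the columns of m;
    # entries where m is nonzero take that row's weight, the rest keep t.
    return jnp.where(m != 0, w[:, None], t)

def fill_row(t_row, w_i, m_row):
    # The same rule for a single row: w_i is a scalar weight.
    return jnp.where(m_row != 0, w_i, t_row)

# vmap maps fill_row over the leading axis of all three arguments.
fill_by_vmap = jax.jit(jax.vmap(fill_row))

t1_broadcast = fill_by_broadcast(t1, weights_w1, mask)
t1_vmap = fill_by_vmap(t1, weights_w1, mask)
print(t1_broadcast)
```

Both versions keep existing values of `t1` where the mask is zero, which matches what the per-pair loop does to its `o_full` argument.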
Working code with for loop
I can do this in a for loop (which takes a long time).
```python
import jax
import numpy as np
import jax.numpy as jnp

def sparsify_weights(args):
    pair_indices, a_full, w_init = args
    i, j = pair_indices
    # jax.ops.index_update was removed in newer JAX releases;
    # .at[...].set(...) is the functional-update replacement.
    a_full = a_full.at[i, j].set(w_init[i])
    return a_full

def prep_sparsify_weights(o_full, w_init, mask, num_parallel):
    # (row, col) positions of the nonzero mask entries
    index_pairs = list(zip(*np.nonzero(mask)))
    # one functional update per nonzero entry, dispatched from Python
    for p in index_pairs:
        o_full = sparsify_weights([p, o_full, w_init])
    return o_full

t1 = jnp.zeros((4, 2))
weights_w1 = jax.numpy.asarray([0.55312582, 0.20736367, -0.74858772, 0.82512423])
mask = jax.numpy.asarray([[1., 0.],
                          [1., 0.],
                          [0., 1.],
                          [0., 1.]])
t1 = prep_sparsify_weights(t1, weights_w1, mask, 1)
print(t1)
```
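The loop above applies one update per index pair. A sketch that keeps the same index-pair logic but issues a single scatter instead: collect all nonzero `(row, col)` pairs at once and pass the index arrays to `.at[...].set(...)`. Inside `jax.jit`, `jnp.nonzero` needs a static `size`; this assumes the number of nonzeros is known before the call (passed as the hypothetical `nnz` argument, and `sparsify_weights_batched` is likewise a hypothetical name).

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames="nnz")
def sparsify_weights_batched(o_full, w_init, mask, nnz):
    # All (row, col) positions of nonzero mask entries in one call;
    # size=nnz gives the result a static shape so it can be traced.
    rows, cols = jnp.nonzero(mask, size=nnz)
    # One scatter: o_full[rows[k], cols[k]] = w_init[rows[k]] for every k.
    return o_full.at[rows, cols].set(w_init[rows])

t1 = jnp.zeros((4, 2))
weights_w1 = jnp.asarray([0.55312582, 0.20736367, -0.74858772, 0.82512423])
mask = jnp.asarray([[1., 0.],
                    [1., 0.],
                    [0., 1.],
                    [0., 1.]])

# Count nonzeros outside jit, where the value is concrete.
nnz = int((mask != 0).sum())
t1 = sparsify_weights_batched(t1, weights_w1, mask, nnz=nnz)
print(t1)
```

`nnz` should match the true count: if it is larger, `jnp.nonzero` pads with index 0, which would overwrite `t1[0, 0]`.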
Code attempt with lax.scan:
I have not gotten my `lax.scan` version to work. What am I missing?
```python
import jax, sys
import numpy as np
from functools import partial
from jax import jit
import jax.numpy as jnp
from jax import lax

def paraUpdate(pair_indices, a_full, w_init):
    i, j = pair_indices
    # .at[...].set(...) replaces jax.ops.index_update in newer JAX releases
    a_full = a_full.at[i, j].set(w_init[i])
    return a_full

#@jax.jit
#@partial(jit, static_argnums=1)
def filter_jax2(o_full, w_init, jax_ranger):
    return lax.scan(paraUpdate, jax_ranger, o_full, w_init)

t1 = jnp.zeros((4, 2))
weights_w1 = jax.numpy.asarray([0.55312582, 0.20736367, -0.74858772, 0.82512423])
mask = jax.numpy.asarray([[1., 0.],
                          [1., 0.],
                          [0., 1.],
                          [0., 1.]])
ranger = list(zip(*np.nonzero(mask)))
t1 = filter_jax2(t1, weights_w1, jnp.asarray(ranger))
```
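For reference, `lax.scan(f, init, xs)` calls `f(carry, x)` once per leading-axis slice of `xs`, and `f` must return a pair `(new_carry, y)`. In the attempt, `jax_ranger` is passed as the initial carry, `o_full` as the sequence to scan over and `w_init` in the position of the `length` argument, while `paraUpdate` takes three arguments and returns a single array. A minimal sketch, assuming the index pairs are precomputed outside `jit` as in the attempt, threads the matrix through as the carry and scans over the pairs; `scan_update` is a hypothetical name.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import lax

def scan_update(o_full, w_init, index_pairs):
    def body(a_full, pair):
        # carry = the matrix being filled, x = one (i, j) index pair
        i, j = pair[0], pair[1]
        a_full = a_full.at[i, j].set(w_init[i])
        return a_full, None  # no per-step output is collected

    final, _ = lax.scan(body, o_full, index_pairs)
    return final

scan_update_jit = jax.jit(scan_update)

t1 = jnp.zeros((4, 2))
weights_w1 = jnp.asarray([0.55312582, 0.20736367, -0.74858772, 0.82512423])
mask = np.asarray([[1., 0.],
                   [1., 0.],
                   [0., 1.],
                   [0., 1.]])

# (num_nonzero, 2) integer array of (row, col) pairs, built with NumPy outside jit
ranger = jnp.asarray(np.argwhere(mask))
t1 = scan_update_jit(t1, weights_w1, ranger)
print(t1)
```

A scan still applies the updates one after another inside a single compiled loop, so it removes Python dispatch overhead but does not parallelize the writes the way one scatter or `jnp.where` does.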