Curious if anyone has used jax.experimental.sparse when passing data to a numpyro model? Has anyone had success from a memory and speed perspective when using the BCOO matrices and sparsifying all of the jnp functions?
I see. This is helpful, but it doesn't seem to use JAX, and I still want to take advantage of JAX compilation. I was curious about the JAX sparse implementation and whether anyone recommends it.
The sample method does not leverage the sparsity (though I think it should be doable), but the log prob involves sparse matmul operations. The distribution should also be jit-compilable.
so you’re saying i can jit compile with scipy sparse matrices?
My experience is limited. I used JAX sparse on some graph neural network work. It worked fine IIRC.
If you look at the implementation, you will see how the scipy matrix is converted into a JAX sparse matrix. The rest of the implementation should then be jit-compilable.
hmm, I'm not seeing that in the code. I see this:
# TODO: look into future jax sparse csr functionality and other developments
self.adj_matrix = _to_sparse(adj_matrix)
It is something like
adj_matrix = BCOO.from_scipy_sparse(adj_matrix)
... adj_matrix @ phi[..., jnp.newaxis]
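For reference, here is a minimal self-contained sketch of that pattern; the matrix data and the function name `apply_adj` are illustrative, not from the library:

```python
import jax
import jax.numpy as jnp
import scipy.sparse
from jax.experimental.sparse import BCOO

# A small scipy sparse adjacency matrix (illustrative data).
adj_scipy = scipy.sparse.csr_matrix(
    [[0.0, 1.0, 0.0],
     [1.0, 0.0, 1.0],
     [0.0, 1.0, 0.0]]
)

# Convert once to JAX's batched-COO format.
adj_matrix = BCOO.from_scipy_sparse(adj_scipy)

@jax.jit
def apply_adj(phi):
    # Sparse-dense matmul; BCOO supports the @ operator under jit.
    return (adj_matrix @ phi[..., jnp.newaxis]).squeeze(-1)

phi = jnp.array([1.0, 2.0, 3.0])
print(apply_adj(phi))  # dense result of adj @ phi
```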
There are two scenarios depending on your application:
- if your sparse matrix is the input of your jitted program, you will need to convert to jax sparse outside your program
- if your sparse matrix is a global constant, you can convert to jax sparse inside your program
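A sketch of the two scenarios, assuming a toy matrix and illustrative function names:

```python
import jax
import jax.numpy as jnp
import scipy.sparse
from jax.experimental.sparse import BCOO

sp = scipy.sparse.eye(3, format="csr") * 2.0

# Scenario 1: the sparse matrix is an input to the jitted program.
# Convert to BCOO outside, then pass it in; jit traces it like any pytree.
@jax.jit
def matvec(mat, x):
    return mat @ x

mat_bcoo = BCOO.from_scipy_sparse(sp)
print(matvec(mat_bcoo, jnp.ones(3)))

# Scenario 2: the sparse matrix is a global constant.
# The conversion can then happen inside the jitted function, since the
# scipy matrix is concrete at trace time and gets baked in as a constant.
@jax.jit
def matvec_const(x):
    mat = BCOO.from_scipy_sparse(sp)
    return mat @ x

print(matvec_const(jnp.ones(3)))
```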
thanks so much for the help. i'm doing the first scenario, so i also need to sparsify all of the jnp functions in my model, correct? seems annoying, but perhaps it's worth it
that sounds right to me.
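For the record, `jax.experimental.sparse.sparsify` can lift an existing jnp-based function to accept BCOO inputs for the primitives it supports, which may reduce how much manual rewriting is needed. A minimal sketch (the function names `model_fn`/`sparse_fn` are illustrative):

```python
import jax
import jax.numpy as jnp
from jax.experimental import sparse

# A plain dense function written with ordinary jnp ops.
def model_fn(mat, x):
    return jnp.sum(mat @ x)

# sparsify transforms it so supported ops also accept BCOO arrays.
sparse_fn = jax.jit(sparse.sparsify(model_fn))

mat = sparse.BCOO.fromdense(jnp.eye(3) * 3.0)
print(sparse_fn(mat, jnp.ones(3)))
```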