jax.experimental.sparse for better memory/speed performance?

Curious if anyone has used jax.experimental.sparse when passing data to a NumPyro model. Has anyone had success, from a memory and speed perspective, with using BCOO matrices and sparsifying all of the jnp functions?

Hi @bball369, you can check out the CAR distribution as an example.

I see. This is helpful, but it doesn't seem to be using JAX, and I still want to take advantage of JAX compilation. I was curious about the JAX sparse implementation and whether anyone recommends it.

The sample method does not leverage the sparsity (though I think it should be doable), but the log_prob involves sparse matmul operators. The distribution should also be jit-compilable.

So you're saying I can jit-compile with scipy sparse matrices?

My experience is limited. I used JAX sparse on some graph neural network stuff. It worked fine, IIRC.

If you look at the implementation, you will see how the scipy sparse matrix is converted into a JAX sparse (BCOO) matrix. The rest of the implementation can then be jit-compiled.

Hmm, I'm not seeing that in the code. I see this:

            # TODO: look into future jax sparse csr functionality and other developments
            self.adj_matrix = _to_sparse(adj_matrix)

It is something like:

    from jax.experimental.sparse import BCOO

    adj_matrix = BCOO.from_scipy_sparse(adj_matrix)
    ... adj_matrix @ phi[..., jnp.newaxis]

There are two scenarios depending on your application (a minimal sketch follows the list):

  • if your sparse matrix is the input of your jitted program, you will need to convert to jax sparse outside your program
  • if your sparse matrix is a global constant, you can convert to jax sparse inside your program
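
To make both patterns concrete, here is a minimal sketch, assuming a small random scipy matrix and a hypothetical matvec function (the names are illustrative, not from the thread):

    import jax
    import jax.numpy as jnp
    import scipy.sparse
    from jax.experimental.sparse import BCOO

    # a small random scipy sparse matrix standing in for your data
    sp_mat = scipy.sparse.random(4, 4, density=0.25)

    # Scenario 1: the sparse matrix is an input of the jitted program,
    # so the conversion to BCOO happens outside jit.
    @jax.jit
    def matvec(mat, v):
        return mat @ v

    mat = BCOO.from_scipy_sparse(sp_mat)  # convert outside the jitted function
    out1 = matvec(mat, jnp.ones(4))

    # Scenario 2: the sparse matrix is a global constant, so the conversion
    # can live inside the jitted function; it runs at trace time and the
    # result is baked into the compiled program as a constant.
    @jax.jit
    def matvec_const(v):
        mat = BCOO.from_scipy_sparse(sp_mat)
        return mat @ v

    out2 = matvec_const(jnp.ones(4))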

Thanks so much for the help. I am doing scenario 1, so then I need to sparsify all of the jnp functions in my model too, correct? Seems annoying, but perhaps it is worth it.

That sounds right to me.
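
For what it's worth, the sparsify transform in jax.experimental.sparse is meant for exactly that: it wraps a function written with ordinary dense jnp ops so that supported primitives also accept BCOO operands. A minimal sketch (the function name is illustrative):

    import jax.numpy as jnp
    from jax.experimental import sparse

    @sparse.sparsify
    def summed_matvec(mat, v):
        # written with ordinary dense ops; under sparsify, supported
        # primitives (matmul, sum, ...) also accept BCOO inputs
        return jnp.sum(mat @ v)

    sp_mat = sparse.BCOO.fromdense(jnp.eye(4))
    result = summed_matvec(sp_mat, jnp.ones(4))  # identity @ ones, summed -> 4.0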
