Batching sparse BCOO objects

Hi all,

I want to draw batches from a sparse document-term matrix inside a jit-compiled function. When I do:

import numpy as np
import scipy.sparse as sparse
import jax
import jax.numpy as jnp
from jax import jit, random
from jax.experimental.sparse import BCOO
from sklearn.feature_extraction.text import CountVectorizer

# some df1 with text data
cv = CountVectorizer(stop_words='english', min_df = 2)
counts = sparse.csr_matrix(cv.fit_transform(df1["Text"]), dtype = np.float32)

# creating sparse jax object
counts_jnp = BCOO.from_scipy_sparse(counts) 
num_documents = counts_jnp.shape[0]
batch_size = 1024

# batch function
@jit
def get_batch(rng, Y):
    D_batch = random.choice(rng, jnp.arange(num_documents), shape=(batch_size,))
    return Y.todense()[D_batch], D_batch  

Y_batch, D_batch = get_batch(random.PRNGKey(1), counts_jnp)

This works fine and is extremely fast when the number of documents is not too high (e.g. num_documents == 100k). However, when I increase the number of documents to, say, 1M, get_batch becomes quite slow, because I densify the whole Y matrix before indexing the relevant documents: Y.todense()[D_batch].
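One way to avoid densifying the whole matrix is to sample rows on the host with SciPy and densify only the sampled rows. This is a sketch under my own assumptions, not something proposed in the thread. The helper name get_batch_host and the synthetic counts matrix below are hypothetical stand-ins for the real data, and the sampling happens outside jit.

```python
import numpy as np
from scipy.sparse import random as sparse_random

# Hypothetical stand-in for the real document-term matrix (CSR, float32).
counts = sparse_random(10_000, 2_000, density=0.01, format="csr",
                       dtype=np.float32, random_state=0)


def get_batch_host(rng, Y, batch_size):
    """Sample row indices on the host and densify only those rows."""
    d_batch = rng.integers(0, Y.shape[0], size=batch_size)
    # CSR row indexing stays sparse; only the (batch_size, vocab) slice is densified.
    y_batch = Y[d_batch].toarray()
    return y_batch, d_batch


rng = np.random.default_rng(0)
Y_batch, D_batch = get_batch_host(rng, counts, batch_size=1024)
```

The resulting NumPy batch can then be passed to a jitted function as an ordinary dense array. The cost is a host-to-device transfer per batch.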

When I try to return

Y[D_batch].todense()

instead, I get an OOM error, which is curious because the former version has no memory issues.
Does anyone have an explanation?
Thanks.

Hi BPro2410,

I can’t immediately tell from the source alone why you would hit an OOM. Could you provide a small working example of this behavior? Then I’ll look into what is going on.

Best, Ola

Hi Ola,

Thanks for your reply. You'll find a minimal example below. I am running the code on an AWS ml.g5.2xlarge instance. When executing the get_batch_2() function, I get the error:

XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 148330749608 bytes.

As I understand it, get_batch_2() should be much more efficient than get_batch_1().
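To put the error in perspective, here is a quick back-of-the-envelope calculation. It uses only the sizes from the example below and the byte count from the error message, and assumes float32 storage (4 bytes per entry). It does not explain where the allocation comes from; it only shows how far the request is from what a dense batch needs.

```python
# Sizes taken from the minimal example and the error message in this post.
num_docs, vocab_size, batch_size = 100_000, 25_000, 1024
bytes_per_float32 = 4

full_dense_bytes = num_docs * vocab_size * bytes_per_float32     # whole matrix, dense
batch_dense_bytes = batch_size * vocab_size * bytes_per_float32  # one dense batch
requested_bytes = 148_330_749_608                                # from XlaRuntimeError

print(f"full dense matrix: {full_dense_bytes / 1e9:.1f} GB")     # 10.0 GB
print(f"one dense batch:   {batch_dense_bytes / 1e6:.1f} MB")    # 102.4 MB
print(f"requested:         {requested_bytes / 1e9:.1f} GB")      # 148.3 GB
print(f"requested / full:  {requested_bytes / full_dense_bytes:.1f}x")
```

So the failing allocation is roughly 15 times larger than the fully densified matrix, even though the batch itself needs only about 100 MB.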

Best,
Ben

import numpy as np
from scipy.sparse import csr_matrix
from jax.experimental.sparse import BCOO
from jax import jit, random
import jax
import jax.numpy as jnp

# Number of documents and vocabulary size
num_docs = 100000
vocab_size = 25000

# Simulating a sparse document-term matrix
# Assuming on average each document contains 50 terms (randomly selected from vocabulary)
density = 50 / vocab_size

# Generate random sparse matrix
rows = np.random.randint(0, num_docs, int(num_docs * vocab_size * density))
cols = np.random.randint(0, vocab_size, int(num_docs * vocab_size * density))
data = np.random.randint(1, 5, int(num_docs * vocab_size * density))  # Random term frequencies between 1 and 4

# Create the csr_matrix
sparse_matrix = csr_matrix((data, (rows, cols)), shape=(num_docs, vocab_size), dtype = np.float32)


# Transfer to jax BCOO object
counts_jnp = BCOO.from_scipy_sparse(sparse_matrix) 
num_documents = counts_jnp.shape[0]


@jit
def get_batch_1(rng, Y):
    D_batch = random.choice(rng, jnp.arange(num_documents), shape=(1024,))
    return Y.todense()[D_batch], D_batch 

@jit
def get_batch_2(rng, Y):
    D_batch = random.choice(rng, jnp.arange(num_documents), shape=(1024,))
    return Y[D_batch].todense(), D_batch 


# -- WORKS --
Y_batch, D_batch = get_batch_1(random.PRNGKey(0), counts_jnp)

# -- FAILS - OOM --
Y_batch, D_batch = get_batch_2(random.PRNGKey(0), counts_jnp)
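A possible workaround that stays inside jit without going through BCOO indexing is to store each document as a fixed-width, zero-padded list of (column, value) pairs and scatter only the sampled rows into a dense batch. This is a sketch under my own assumptions, not the method from this thread or an official JAX recipe. The names to_padded_rows and get_batch_padded are hypothetical, and the small random matrix only serves as a demo.

```python
from functools import partial

import numpy as np
import jax
import jax.numpy as jnp
from scipy.sparse import csr_matrix


def to_padded_rows(mat):
    """Convert a SciPy sparse matrix into (num_rows, max_nnz) column/value arrays."""
    csr = mat.tocsr().copy()
    csr.sum_duplicates()
    nnz_per_row = np.diff(csr.indptr)
    max_nnz = max(int(nnz_per_row.max()), 1)
    n_rows = csr.shape[0]
    # Padding uses column 0 with value 0, which adds nothing when scattered.
    cols = np.zeros((n_rows, max_nnz), dtype=np.int32)
    vals = np.zeros((n_rows, max_nnz), dtype=np.float32)
    row_ids = np.repeat(np.arange(n_rows), nnz_per_row)
    pos = np.arange(csr.nnz) - np.repeat(csr.indptr[:-1], nnz_per_row)
    cols[row_ids, pos] = csr.indices
    vals[row_ids, pos] = csr.data
    return jnp.asarray(cols), jnp.asarray(vals)


@partial(jax.jit, static_argnames=("batch_size", "vocab_size"))
def get_batch_padded(rng, cols, vals, batch_size, vocab_size):
    """Sample rows (with replacement) and build a dense (batch_size, vocab_size) batch."""
    d_batch = jax.random.randint(rng, (batch_size,), 0, cols.shape[0])
    b_cols = cols[d_batch]                     # (batch_size, max_nnz)
    b_vals = vals[d_batch]                     # (batch_size, max_nnz)
    row_idx = jnp.arange(batch_size)[:, None]  # broadcasts against b_cols
    dense = jnp.zeros((batch_size, vocab_size), dtype=vals.dtype)
    return dense.at[row_idx, b_cols].add(b_vals), d_batch


# Small demo matrix with term counts between 1 and 4.
demo_rng = np.random.default_rng(0)
demo_dense = ((demo_rng.random((50, 30)) < 0.1)
              * demo_rng.integers(1, 5, (50, 30))).astype(np.float32)
cols, vals = to_padded_rows(csr_matrix(demo_dense))
Y_demo, D_demo = get_batch_padded(jax.random.PRNGKey(0), cols, vals,
                                  batch_size=8, vocab_size=30)
```

The padded arrays take num_docs * max_nnz entries each, so this layout only pays off when no single document has a very large number of distinct terms.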

Thanks Ben. I have some time to look at it in the morning.

Hi Ben,

I believe the GitHub discussion "Jax sparse-sparse matrix multiplication memory usage" (jax-ml/jax, Discussion #17251) is related to your problem. I suggest you open a discussion at jax-ml/jax discussions with the example you made. They will probably be able to suggest how to handle the problem faster than I can.

Best, Ola

Hi Ola,

Thank you for your response. I believe the link you provided shows that there are some issues with using the experimental sparse BCOO format. I will also post this question in the JAX forum. I greatly appreciate your assistance.

Thanks!
