I have a model that draws from a Multinomial distribution and I’m getting the following error when I try to sample from it:
XlaRuntimeError: UNIMPLEMENTED: Scatter operations with more than 2147483647 scatter indices are not supported.
The error only occurs when I run the model on my full dataset, so it seems to be related to the size of the draw and possibly to memory use.
I can reproduce the error with this code by increasing the maximum possible count in the multinomial draw:
import numpy as np
import numpyro
import numpyro.distributions as dist
n_units = 3000
n_choices = 8
max_count = 3_000_000  # triggers the error above
# max_count = 1_000  # works without error
logits = np.random.random(size=(n_units, n_choices))
counts = np.random.randint(0, max_count, size=logits.shape[:-1])
print(logits.shape, counts.shape)
print(logits.dtype, counts.dtype)
with numpyro.handlers.seed(rng_seed=1):
    y = numpyro.sample("y", dist.Multinomial(counts, logits=logits))
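For what it's worth, here is a rough size check I did. It assumes (I have not confirmed this in the NumPyro source) that the sampler builds one index per unit for every possible draw up to the largest count. The helper name `estimated_indices` and that scaling rule are my own guess, not NumPyro's documented behavior:

```python
# Limit quoted in the XlaRuntimeError message (2**31 - 1).
SCATTER_LIMIT = 2_147_483_647


def estimated_indices(n_units, max_count):
    # Assumed scaling (unverified): one index per unit per possible draw.
    return n_units * max_count


n_units = 3000
for max_count in (1_000, 3_000_000):
    n_idx = estimated_indices(n_units, max_count)
    # Print the count bound, the estimated index count, and whether it exceeds the limit.
    print(max_count, n_idx, n_idx > SCATTER_LIMIT)
```

If that assumption holds, the failing case would need about 9e9 indices, well above the limit in the error message, while the working case would need only about 3e6.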
I’d like to better understand what’s causing this behavior and whether there is a way to avoid it.