I have a model that draws from a Multinomial distribution and I’m getting the following error when I try to sample from it:
```
XlaRuntimeError: UNIMPLEMENTED: Scatter operations with more than 2147483647 scatter indices are not supported.
```
This error only happens when I draw from the full sample, and it seems to be related to memory use.
I can reproduce the error with this code by increasing the maximum possible count in the multinomial draw:
```python
import numpy as np
import numpyro
import numpyro.distributions as dist

n_units = 3000
n_choices = 8
max_count = 3_000_000  # get above error
# max_count = 1_000  # works

logits = np.random.random(size=(n_units, n_choices))
counts = np.random.randint(0, max_count, size=logits.shape[:-1])
print(logits.shape, counts.shape)
print(logits.dtype, counts.dtype)

with numpyro.handlers.seed(rng_seed=1):
    y = numpyro.sample("y", dist.Multinomial(counts, logits=logits))
```
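A back-of-the-envelope check suggests why the size of the counts, rather than memory as such, might matter. It assumes, without checking NumPyro's source here, that the sampler draws one categorical index per trial for every row, pads each row to the largest count, and then scatter-adds those indices into the result. Under that assumption the number of scatter indices is roughly `n_units * max_count`; `scatter_index_estimate` is a hypothetical helper for that product:

```python
# Assumption (not verified against NumPyro): one scatter index per trial,
# with every row padded to the largest count in the batch.
INT32_MAX = 2**31 - 1  # 2147483647, the limit quoted in the error message

def scatter_index_estimate(n_units: int, max_count: int) -> int:
    """Hypothetical estimate: padded trials per row times number of rows."""
    return n_units * max_count

n_units = 3000
for max_count in (1_000, 3_000_000):
    n = scatter_index_estimate(n_units, max_count)
    print(f"max_count={max_count:,}: about {n:,} indices, exceeds limit: {n > INT32_MAX}")

# Largest padded count that stays under the limit for 3000 rows.
print(INT32_MAX // n_units)  # → 715827
```

Because `counts` comes from `randint(0, max_count)`, the real maximum is just below `max_count`, so these figures are upper bounds. If the assumption holds, 3000 rows would cross the limit once the largest count exceeds about 715,827, which fits `1_000` working and `3_000_000` failing.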
I’d like to better understand what’s causing this behavior and whether there is a way to avoid it.
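One general way to avoid any per-trial step is the conditional-binomial decomposition of a multinomial: draw category 0 from Binomial(n, p_0), then category 1 from the remaining trials with its probability renormalized over the remaining mass, and so on, giving the last category whatever is left. Its work scales with the number of categories, not with the counts. The sketch below is plain NumPy, not NumPyro's `Multinomial` implementation; `softmax` and `multinomial_via_binomials` are hypothetical helpers, and porting the loop to JAX or using it inside a model for inference is not shown or tested here:

```python
import numpy as np

def softmax(logits):
    """Row-wise softmax, shifted by the row max for numerical stability."""
    z = logits - logits.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def multinomial_via_binomials(rng, counts, probs):
    """One multinomial draw per row using K - 1 conditional binomial draws.

    counts: integer array, shape (n_units,)
    probs:  float array, shape (n_units, K), rows summing to 1
    """
    counts = np.asarray(counts, dtype=np.int64)
    probs = np.asarray(probs, dtype=np.float64)
    n_units, k = probs.shape
    out = np.zeros((n_units, k), dtype=np.int64)
    remaining_n = counts.copy()     # trials not yet assigned to a category
    remaining_p = np.ones(n_units)  # probability mass not yet assigned
    for j in range(k - 1):
        # P(category j | not one of categories 0..j-1), clipped against round-off
        p_j = np.clip(probs[:, j] / np.maximum(remaining_p, 1e-300), 0.0, 1.0)
        out[:, j] = rng.binomial(remaining_n, p_j)
        remaining_n -= out[:, j]
        remaining_p -= probs[:, j]
    out[:, -1] = remaining_n        # the last category takes the leftover trials
    return out

# Same sizes as the reproduction above, including the failing max_count.
rng = np.random.default_rng(1)
logits = rng.random((3000, 8))
counts = rng.integers(0, 3_000_000, size=3000)
y = multinomial_via_binomials(rng, counts, softmax(logits))
print(y.shape, (y.sum(axis=1) == counts).all())
```

Each row's entries sum exactly to its count by construction, since every binomial draw is bounded by the trials still unassigned. Whether an approach like this sidesteps the error in NumPyro depends on the per-trial-scatter assumption being the actual cause.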