Sampling from Multinomial gives `XlaRuntimeError: UNIMPLEMENTED: Scatter operations with more than...`

I have a model that draws from a Multinomial distribution and I’m getting the following error when I try to sample from it:

XlaRuntimeError: UNIMPLEMENTED: Scatter operations with more than 2147483647 scatter indices are not supported.

This error only happens when I run the model on my full dataset, so it seems to be related to the size of the problem (possibly memory use).

I can reproduce the error with this code by increasing the maximum possible count in the multinomial draw:

import numpy as np
import numpyro
import numpyro.distributions as dist

n_units = 3000
n_choices = 8

max_count = 3_000_000    # triggers the error above
# max_count = 1_000      # works

logits = np.random.random(size=(n_units, n_choices))
counts = np.random.randint(0, max_count, size=logits.shape[:-1])

print(logits.shape, counts.shape)
print(logits.dtype, counts.dtype)

with numpyro.handlers.seed(rng_seed=1):
  y = numpyro.sample("y", dist.Multinomial(counts, logits=logits))

I’d like to better understand what’s causing this behavior and whether there is a way to avoid it.

The error says that XLA does not support scatter operations with more than 2147483647 (2^31 - 1) scatter indices. I think the Multinomial sampler uses a scatter operation internally. Could you chase it down?
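A quick back-of-the-envelope check of why the count size might matter. This is only a sketch under an assumption I have not confirmed against the NumPyro source: that the sampler draws one categorical index per unit for each of the `max(counts)` possible trials and then scatter-adds those indices into the per-choice counts. The helper name `scatter_index_upper_bound` is mine, not part of any library.

```python
# Limit quoted in the error message: the largest signed 32-bit integer.
INT32_MAX = 2**31 - 1


def scatter_index_upper_bound(n_units: int, max_count: int) -> int:
    """Upper bound on the number of scatter indices, ASSUMING one index
    per unit per possible trial (hypothetical model of the sampler).

    np.random.randint(0, max_count) excludes max_count, so the largest
    count that can appear in the repro is max_count - 1.
    """
    return n_units * (max_count - 1)


n_units = 3000

failing = scatter_index_upper_bound(n_units, 3_000_000)
working = scatter_index_upper_bound(n_units, 1_000)

print(f"limit from the error message: {INT32_MAX:,}")
print(f"max_count = 3_000_000 -> up to {failing:,} indices, over limit: {failing > INT32_MAX}")
print(f"max_count = 1_000     -> up to {working:,} indices, over limit: {working > INT32_MAX}")
```

If that assumption holds, the failing case would need up to about 9 billion scatter indices, roughly four times the limit, while the working case needs about 3 million. That would also explain why the error only appears at full scale and looks memory-related.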