JAX Error when increasing dataset size

Hi all,

I am running a fairly straightforward SVI optimization for a simple model. This is the training code:

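from jax import random
import numpyro
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoNormal

guide = AutoNormal(model)
optimizer = numpyro.optim.Adam(step_size=0.001)

svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
svi_result = svi.run(random.PRNGKey(0), 15000, covariates, obs=obs)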

This works perfectly fine for a small dataset (the sample count is the first dimension of both "covariates" and "obs"), but once I increase it beyond a certain threshold, I get this JAX error:

TracerIntegerConversionError: The __index__() method was called on the JAX Tracer object Traced<ShapedArray(int32, weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
See JAX Errors — JAX documentation

Has anyone come across this?

Best wishes,
Eyal.

Could you provide reproducible code?
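
In the meantime, here is a minimal sketch of one common way this class of error arises in plain JAX. It is only an illustration and not necessarily what happens in your model: the names `sizes`, `pick_bad` and `pick_ok` are made up for this example. The error is raised when a traced value is used where Python needs a concrete integer, for instance as an index into a Python list inside a jitted function.

```python
import jax
import jax.numpy as jnp
from jax.errors import TracerIntegerConversionError

# Hypothetical data: a plain Python list, not a JAX array.
sizes = [2, 3, 4]

def pick_bad(i):
    # Under jit, `i` is a tracer. Indexing a Python list needs a concrete
    # integer, so Python calls i.__index__(), which JAX refuses to do.
    return sizes[i]

def pick_ok(i):
    # Converting the list to a JAX array lets JAX handle the indexing itself.
    return jnp.asarray(sizes)[i]

try:
    jax.jit(pick_bad)(1)
    raised = False
except TracerIntegerConversionError:
    raised = True

result = int(jax.jit(pick_ok)(1))
print(raised, result)
```

If something similar is hiding in the model (a traced value passed to `range`, used as a shape, or used to index a Python container), a small script that builds synthetic "covariates" and "obs" at the failing size would help pin it down.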