Hi all,
I am using a fairly straightforward SVI optimization for a rather simple model. This is the training code:
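import numpyro
from jax import random
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoNormal
guide = AutoNormal(model)
optimizer = numpyro.optim.Adam(step_size=0.001)
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
svi_result = svi.run(random.PRNGKey(0), 15000, covariates, obs=obs)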
This works perfectly fine for a small data set (the sample count, i.e. the first dimension of “covariates” and “obs”), but once I increase the sample count beyond a certain size, I get this JAX error:
TracerIntegerConversionError: The index() method was called on the JAX Tracer object Traced<ShapedArray(int32, weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
See JAX Errors — JAX documentation
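As far as I understand the JAX error page, this error is raised when a traced integer reaches code that needs a concrete Python int (anything that calls `__index__`, such as `range()`), which cannot work inside a jitted function like the update step that `SVI.run` compiles. Below is a minimal JAX-only sketch of that pattern and two common workarounds. `first_n_sum` and `first_half_sum` are hypothetical functions, not taken from my model, so I can't say this is what happens in my case.

```python
import jax
import jax.numpy as jnp
from jax.errors import TracerIntegerConversionError


def first_n_sum(x, n):
    # Python's range() calls n.__index__(). Under jit, n is a Tracer whose
    # value is unknown at trace time, so this raises TracerIntegerConversionError.
    total = 0.0
    for i in range(n):
        total = total + x[i]
    return total


x = jnp.arange(5.0)

# Passing n as a regular (traced) jit argument fails.
try:
    jax.jit(first_n_sum)(x, 3)
    raised = False
except TracerIntegerConversionError:
    raised = True

# Workaround 1: mark n as static so it stays a concrete Python int while tracing.
static_result = jax.jit(first_n_sum, static_argnums=1)(x, 3)


# Workaround 2: derive sizes from array shapes, which are static under jit.
@jax.jit
def first_half_sum(x):
    half = x.shape[0] // 2  # plain Python int, not a Tracer
    return jnp.sum(x[:half])


shape_based_result = first_half_sum(x)
```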
Has anyone come across this?
Best wishes,
Eyal.