Adam optimizer before NUTS?

Thanks @fehiepsi

I’m trying SVI with an AutoDelta guide, following the Zero-Inflated Poisson regression example in the NumPyro documentation, but I’m getting an error about padding:

JaxStackTraceBeforeTransformation: ValueError: length of padding_config must equal the number of axes of operand, got padding_config ((16, 0, 0),) for operand shape (1, 1)

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

ValueError: length of padding_config must equal the number of axes of operand, got padding_config ((16, 0, 0),) for operand shape (1, 1)
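For reference, my SVI setup follows the pattern from that example. The model and data below are just placeholders standing in for my actual ones:

```python
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro import optim
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoDelta

# placeholder model and data standing in for my actual ones
def model(x, y=None):
    theta = numpyro.sample("theta", dist.Normal(0.0, 1.0))
    numpyro.sample("obs", dist.Normal(theta * x, 1.0), obs=y)

x = jnp.linspace(0.0, 1.0, 100)
y = 2.0 * x

guide = AutoDelta(model)
svi = SVI(model, guide, optim.Adam(step_size=0.01), loss=Trace_ELBO())
svi_result = svi.run(jax.random.PRNGKey(0), 2000, x, y=y)
map_params = svi_result.params
```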

I wonder if this is because my NumPyro model function calls a @shard_map-decorated function that splits the computation across multiple CPU cores (or multiple GPUs). Note that Predictive and NUTS both run fine, without these padding errors. I noticed another short post about this sharding issue where you suggested padding was needed, but I couldn’t follow it: what is different about Predictive/NUTS/MCMC versus SVI that causes such a padding error?
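To be concrete, the structure is roughly like the sketch below (names, shapes, and distributions are placeholders, not my actual model):

```python
from functools import partial

import jax
import numpy as np
import numpyro
import numpyro.distributions as dist
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(np.array(jax.devices()), ("i",))

# the expensive part of the likelihood is split across devices
@partial(shard_map, mesh=mesh, in_specs=(P("i"), P()), out_specs=P("i"))
def forward(x_shard, theta):
    # each device sees a shard of the rows of x and the full (replicated) theta
    return x_shard @ theta

def model(x, y=None):
    theta = numpyro.sample(
        "theta", dist.Normal(0.0, 1.0).expand([x.shape[1]]).to_event(1)
    )
    sigma = numpyro.sample("sigma", dist.HalfNormal(1.0))
    mu = forward(x, theta)
    numpyro.sample("obs", dist.Normal(mu, sigma), obs=y)
```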

For convenience I’d like to use NumPyro’s own SVI for MAP estimation, but could I also just use optax.adam manually, in combination with the value and gradient of numpyro.infer.util.log_density?
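Something like this minimal sketch is what I have in mind (placeholder model and data again; I realize that constrained sites would need a transform, which AutoDelta normally takes care of):

```python
import jax
import jax.numpy as jnp
import optax
import numpyro
import numpyro.distributions as dist
from numpyro.infer.util import log_density

# placeholder model and data standing in for my actual ones
def model(x, y=None):
    theta = numpyro.sample("theta", dist.Normal(0.0, 1.0))
    numpyro.sample("obs", dist.Normal(theta * x, 1.0), obs=y)

x = jnp.linspace(0.0, 1.0, 100)
y = 2.0 * x

def neg_log_joint(params):
    # log_density returns (log_joint, model_trace); `params` supplies values
    # for the latent sample sites (in constrained space)
    log_joint, _ = log_density(model, (x,), {"y": y}, params)
    return -log_joint

params = {"theta": jnp.zeros(())}
optimizer = optax.adam(1e-2)
opt_state = optimizer.init(params)

@jax.jit
def step(params, opt_state):
    loss, grads = jax.value_and_grad(neg_log_joint)(params)
    updates, opt_state = optimizer.update(grads, opt_state)
    return optax.apply_updates(params, updates), opt_state, loss

for _ in range(2000):
    params, opt_state, loss = step(params, opt_state)
```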