Thanks @fehiepsi
I’m trying SVI with an AutoDelta guide, following the "Example: Zero-Inflated Poisson regression model" in the NumPyro documentation, but I’m getting an error about padding:
```
JaxStackTraceBeforeTransformation: ValueError: length of padding_config must equal the number of axes of operand, got padding_config ((16, 0, 0),) for operand shape (1, 1)

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

ValueError: length of padding_config must equal the number of axes of operand, got padding_config ((16, 0, 0),) for operand shape (1, 1)
```
I wonder if this is because my NumPyro model function calls a `@shard_map`-decorated function that splits the computation across multiple CPU cores (or multiple GPUs). Note that `Predictive` and NUTS both run fine without these padding errors. I noticed another short post about this sharding issue where you suggested padding was needed, but I couldn’t follow it – what is different about `Predictive`/NUTS/MCMC vs. SVI that causes such a padding error?
For convenience, I’d like to use NumPyro’s own SVI for MAP estimation, but could I also just use `optax.adam` manually, combined with the value and grad of `numpyro.infer.util.log_density`?