I’m trying to infer the parameters of a non-linear ODE system. Would it be useful to run a gradient-descent optimizer like Adam (e.g., from optax) first to find a good starting point for NUTS? Is something like this already implemented in numpyro?
I’m finding that the time to convergence for my NUTS inference is very sensitive to how small the uncertainties in my Gaussian likelihood are. I suspect that when the uncertainties are small, the posterior landscape becomes very steep, so NUTS wants to take many small steps (depending on whether numpyro.infer.init_to_sample lands in a flat or steep region). Perhaps I can bypass this by running Adam first (with an initially large, exponentially decaying learning rate) and then initializing NUTS where Adam stopped.
Any other strategies someone can suggest? I’ve also tried decreasing target_accept_prob to 0.5 and max_tree_depth to 8.
Are you asking how to set the initial starting point? If so, you can use the init_to_value strategy, as in Example: MCMC Methods for Tall Data — NumPyro documentation
Thanks @fehiepsi, that’s very interesting but not my current use case. I think my posterior landscape is both very wiggly and very steep, so NUTS takes very small steps and gets stuck near the init_to_value guess. I was instead thinking of something like: run SVI (or Adam) first to get a MAP estimate of the parameters, then initialize NUTS at that MAP estimate, kind of like here:
Is this sequential two-stage SVI/Adam-then-NUTS approach usually a good strategy for complicated posteriors? Is it common to do?
It is a good strategy. As in the example in my last comment, if we don’t initialize at the MAP estimate, NUTS will have trouble mixing. Not sure why you said it is not relevant.
Another approach is to neutralize the posterior geometry, like in Example: Neural Transport — NumPyro documentation
Thanks @fehiepsi
I’m trying SVI with an AutoDelta guide, following Example: Zero-Inflated Poisson regression model — NumPyro documentation, but I’m getting an error about padding:
JaxStackTraceBeforeTransformation: ValueError: length of padding_config must equal the number of axes of operand, got padding_config ((16, 0, 0),) for operand shape (1, 1)
The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
ValueError: length of padding_config must equal the number of axes of operand, got padding_config ((16, 0, 0),) for operand shape (1, 1)
I wonder if this is because my numpyro model function calls a @shard_map-decorated function that splits the computation across multiple CPU cores (or multiple GPUs). Note that Predictive and NUTS both run fine without these padding errors. I noticed another short post about this sharding issue where you suggested padding was needed, but I couldn’t follow it. What is different about Predictive/NUTS/MCMC vs. SVI that causes such a padding error?
For convenience I’d like to use numpyro’s own SVI for MAP estimation, but could I also just use optax.adam manually, combined with the value and gradient of numpyro.infer.util.log_density?
It is tricky for me to interpret the behavior of shard_map. Could you try to isolate the issue?
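One way to isolate it, as a rough sketch: pull the @shard_map function out of the model and apply, one at a time, the transforms SVI layers on top (e.g. grad for the ELBO, vmap when using multiple particles) to see which one reproduces the padding error. Everything below is a made-up stand-in for the real function, using a one-device mesh so it runs anywhere:

```python
from functools import partial
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

# One-device mesh so this sketch runs anywhere; swap in the real mesh.
mesh = Mesh(np.array(jax.devices()[:1]), axis_names=("i",))

@partial(shard_map, mesh=mesh, in_specs=(P("i"),), out_specs=P("i"))
def f(x):
    # Stand-in for the real per-shard computation.
    return jnp.sin(x) * 2.0

x = jnp.arange(4.0)
y = f(x)  # plain call, roughly what Predictive/NUTS exercise

# Try each transform separately and report which one fails.
for name, thunk in [
    ("jit", lambda: jax.jit(f)(x)),
    ("grad", lambda: jax.grad(lambda v: f(v).sum())(x)),
    ("vmap", lambda: jax.vmap(f)(jnp.stack([x, x]))),
]:
    try:
        thunk()
        print(name, "ok")
    except Exception as e:
        print(name, "failed:", type(e).__name__, e)
```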