Hi everyone,

I’ve been trying to implement high-dimensional ODE models in NumPyro using NUTS, and this forum has been a great help thus far! I was able to speed up my code immensely. Using a JAX implementation, I can now obtain maximum likelihood estimates of my full model in a few minutes.

These estimates are helpful (for instance, I’ve learned that many of the parameters are highly correlated) but I would like to utilize prior information and obtain a posterior distribution over the parameters.

Hence, I want to do a ‘fully Bayesian’ analysis with MCMC. Unfortunately, this has been giving me lots of trouble. For starters, my full model (about 60-70 parameters, excluding the likelihood sigmas) simply does not converge on real data. Sampling takes an immense amount of time, and after a day or so the kernel crashes and restarts.

Therefore, I’ve been trying to work with synthetic/fake data and a smaller model. However, even very small models have been challenging: I have yet to run an ODE model efficiently and obtain a reasonable rhat statistic and effective sample size, or even recover the parameters that generated the data. I’ve also noticed that some chains generally take much longer than others, which suggests that some initializations are more problematic than others. I’ve tried tighter priors as well, to little avail. I’ve not yet experimented with different likelihood functions, because right now I’m not even adding noise to the fake data.

I’ve created the simplest ODE system I could think of: a system of two first-order linear ODEs with one parameter each, and a dataset of 3 fake individuals with 12 time steps each.
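To give an idea of the shape of this dataset without opening the notebook, here is a minimal sketch. It is not the notebook's code: it assumes two decoupled decay equations, dy_i/dt = -k_i * y_i, with made-up rate parameters and initial states, and it uses the closed-form solution in place of an ODE solver.

```python
import numpy as np

# Hypothetical values: the actual equations, parameters, and initial states
# live in the notebook and are not reproduced here.
TRUE_K = np.array([0.5, 0.2])            # one rate parameter per equation
N_INDIVIDUALS, N_TIMESTEPS = 3, 12

def simulate(k, y0, t):
    """Closed-form solution of dy_i/dt = -k_i * y_i for each state i.

    k:  (2,)                 rate parameters
    y0: (n_individuals, 2)   initial states
    t:  (n_timesteps,)       observation times
    returns an array of shape (n_individuals, n_timesteps, 2)
    """
    return y0[:, None, :] * np.exp(-k[None, None, :] * t[None, :, None])

t = np.linspace(0.0, 11.0, N_TIMESTEPS)                # 12 evenly spaced times
y0 = np.array([[1.0, 2.0], [1.5, 1.0], [0.5, 3.0]])    # one row per individual
data = simulate(TRUE_K, y0, t)                         # noise-free fake data
print(data.shape)
```

With no noise added, every observation lies exactly on the true trajectory, so the only things the sampler has to find are the two rate parameters (plus whatever likelihood scale is sampled).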

I’ve put the model in this editable colab notebook: Google Colab

Running the model with 4 NUTS chains (1000 warmup + 1000 samples each) takes more than 30 minutes with default settings. Even then, the diagnostics are extremely bad (n_eff = 2, rhat of almost 40) and the parameters are not properly identified, even though a traditional minimizer finds the correct MLE in a split second.
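For context on how bad these numbers are, here is a rough sketch of the classic (non-split) Gelman-Rubin statistic on made-up draws. It is only an illustration: the chain values are invented, and as far as I understand NumPyro's summary reports a split variant, so its numbers will differ somewhat. The point is that an rhat in the tens is what you get when each chain sits in its own narrow region and the chains never overlap.

```python
import numpy as np

def gelman_rubin(chains):
    """Classic (non-split) R-hat for draws of shape (n_chains, n_samples)."""
    n = chains.shape[1]
    within = chains.var(axis=1, ddof=1).mean()       # W: mean within-chain variance
    between = chains.mean(axis=1).var(ddof=1)        # B / n: variance of chain means
    var_hat = (n - 1) / n * within + between         # pooled variance estimate
    return float(np.sqrt(var_hat / within))

rng = np.random.default_rng(0)

# Four chains that all explore the same distribution.
mixed = rng.normal(0.0, 1.0, size=(4, 1000))

# Four chains stuck near different (invented) values with tiny spread.
centers = np.array([0.1, 0.5, 1.0, 2.0])[:, None]
stuck = centers + rng.normal(0.0, 0.01, size=(4, 1000))

rhat_mixed = gelman_rubin(mixed)
rhat_stuck = gelman_rubin(stuck)
print(f"well-mixed chains: rhat = {rhat_mixed:.3f}")
print(f"stuck chains:      rhat = {rhat_stuck:.1f}")
```

In the stuck case the between-chain variance dwarfs the within-chain variance, which is also consistent with an effective sample size close to the number of chains rather than the number of draws.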

Can someone please help me understand why the model performs the way it does? I would love to use numpyro to run the more complex ODE models but right now that seems impossible.