Pyro and differential equations

Is it possible to infer differential equation parameters using Pyro? I found an example with NumPyro and was wondering whether this is possible with Pyro as well.

Probably. In principle you just need an appropriate PyTorch implementation of some ODE solver, perhaps this one.

Just note that, depending on the details, inferring ODE parameters can be difficult, and in any case NumPyro is likely to be faster for many ODE models.

Many thanks.

Oh, that is a very helpful insight. Is NumPyro faster in general for SVI? Is it better to shift to NumPyro?

It's hard to say in general terms, but certainly for something like an ODE integrator (which tends to involve long loops) the JAX compiler is expected to do a much better job of optimization than PyTorch.
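To illustrate the kind of loop being discussed, here is a sketch of the same hypothetical RK4 decay integrator written in JAX. It assumes the time loop is expressed with `jax.lax.scan` and wrapped in `jax.jit`, so XLA traces the step once and compiles the whole integration as a single program. It only shows the structure; it is not a benchmark.

```python
import jax
import jax.numpy as jnp


def decay(y, t, k):
    # Hypothetical example ODE: dy/dt = -k * y
    return -k * y


def rk4_step(f, y, t, h, theta):
    # One classical Runge-Kutta (RK4) step of size h
    k1 = f(y, t, theta)
    k2 = f(y + 0.5 * h * k1, t + 0.5 * h, theta)
    k3 = f(y + 0.5 * h * k2, t + 0.5 * h, theta)
    k4 = f(y + h * k3, t + h, theta)
    return y + h / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4)


@jax.jit
def solve_decay(y0, ts, k):
    """Integrate the decay ODE over the time grid ``ts``; returns y at every t."""
    y0 = jnp.asarray(y0, dtype=ts.dtype)

    def body(y, t_pair):
        t0, t1 = t_pair
        y_next = rk4_step(decay, y, t0, t1 - t0, k)
        return y_next, y_next  # (new carry, value to stack)

    # lax.scan expresses the time loop as one compiled XLA loop
    _, ys = jax.lax.scan(body, y0, (ts[:-1], ts[1:]))
    return jnp.concatenate([y0[None], ys], axis=0)


if __name__ == "__main__":
    ts = jnp.linspace(0.0, 5.0, 26)
    print(solve_decay(1.0, ts, 0.7))
    # Gradient of the summed trajectory w.r.t. k, through the compiled loop
    print(jax.grad(lambda k: solve_decay(1.0, ts, k).sum())(0.7))
```

Compared with the op-by-op PyTorch loop, the scanned body is compiled once, which is one plausible reason ODE-heavy models tend to run faster under JAX/NumPyro.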