Filter result of ODE integration for inference


I have a model which is based on integrating a d-dimensional ordinary differential equation (ODE), where the dimensionality is given by the parameter n_components.

sigma = numpyro.sample("sigma", dist.LogNormal(-1, 1).expand([n_components]))
y = odeint(dydt, ...)
numpyro.sample("y", dist.Normal(y, sigma), obs=y)

The result of the ODE integration (y) is an array of shape (n_timesteps x n_components), and in the code above inference is performed over the whole output array y. In my case I want to filter this further: depending on the component, I want to consider a different subset of time steps in the inference.

An example would be a system with n_components = 3 and 10 time steps in the ODE integration, leading to a 10 x 3 array for y. Now, in the inference, I want to consider only time steps [0, 1, 2, 5] for component 0, time steps [1, 2, 8, 9] for component 1, and time steps [0, 2, 3, 4, 5, 7, 8] for component 2.
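The per-component index lists in this example can be turned into a single boolean array of shape (n_timesteps, n_components), with True wherever an observation should enter the likelihood. A minimal NumPy sketch, assuming time runs along axis 0 and components along axis 1 (the layout odeint returns) and 0-based component indices, so the third component is index 2. The names obs_steps and build_mask are hypothetical:

```python
import numpy as np

n_timesteps, n_components = 10, 3

# Hypothetical per-component lists of observed time-step indices,
# taken from the example above (component indices are 0-based).
obs_steps = [
    [0, 1, 2, 5],            # component 0
    [1, 2, 8, 9],            # component 1
    [0, 2, 3, 4, 5, 7, 8],   # component 2
]


def build_mask(obs_steps, n_timesteps):
    """Return a boolean array of shape (n_timesteps, len(obs_steps)).

    Entry [t, c] is True if component c is observed at time step t.
    """
    mask = np.zeros((n_timesteps, len(obs_steps)), dtype=bool)
    for c, steps in enumerate(obs_steps):
        mask[steps, c] = True  # fancy indexing sets all listed rows at once
    return mask


my_mask = build_mask(obs_steps, n_timesteps)
print(my_mask.shape)        # (10, 3)
print(my_mask.sum(axis=0))  # [4 4 7]
```

Building the mask once outside the model keeps the per-component bookkeeping out of the likelihood, which then stays a single vectorized call.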

What is the best way to do this?



What does “consider” mean? Does it mean including an observation for those entries, or something else?

It means including an observation for them. Of course, odeint always returns the complete matrix y (n_timesteps x n_components). The final y_obs could then have the form described above: instead of an array, it would be a list of lists, with one sublist per component.

I noticed that I have a typo in the code above. It should of course be:

numpyro.sample("y", dist.Normal(y, sigma), obs=y_obs)

Any idea how best to set this up?

You should be able to use the distribution's .mask() method, something like:

numpyro.sample("y", dist.Normal(y, sigma).mask(my_mask), obs=y_obs)

Thanks, I will give it a try.

One question: what form does y_obs need to have then? Would it be an array of shape n_timesteps x n_components in which the masked elements are simply ignored and can be anything? Or would it be a list of lists (or 1-D arrays) containing only the valid observations?

It would be a full array; you need this for vectorization/speed.

The masked-out elements are arbitrary and will be ignored, although you probably want to make sure they're not NaN.
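If the observations start out in ragged list-of-lists form (one sublist per component), they can be scattered into a full array plus a matching mask before inference. A minimal NumPy sketch, assuming each component comes with a list of time-step indices and an equally long list of values; the helper to_dense and the data below are hypothetical. Unobserved entries get a fill value of 0.0; any finite value should do, while NaN is best avoided (a likely reason: with where-style masking, a NaN in an unselected entry can still leak into gradients):

```python
import numpy as np


def to_dense(steps_per_comp, values_per_comp, n_timesteps, fill_value=0.0):
    """Scatter ragged per-component observations into a full array.

    Returns (y_obs, mask), both of shape (n_timesteps, n_components).
    Unobserved entries of y_obs hold fill_value and are False in mask.
    """
    n_components = len(steps_per_comp)
    y_obs = np.full((n_timesteps, n_components), fill_value, dtype=float)
    mask = np.zeros((n_timesteps, n_components), dtype=bool)
    for c, (steps, values) in enumerate(zip(steps_per_comp, values_per_comp)):
        if len(steps) != len(values):
            raise ValueError(f"component {c}: {len(steps)} steps but {len(values)} values")
        y_obs[steps, c] = values
        mask[steps, c] = True
    return y_obs, mask


# Hypothetical data for the 10 x 3 example (values are made up).
steps_per_comp = [[0, 1, 2, 5], [1, 2, 8, 9], [0, 2, 3, 4, 5, 7, 8]]
values_per_comp = [
    [1.0, 0.9, 0.8, 0.6],
    [2.0, 1.8, 1.1, 1.0],
    [0.5, 0.4, 0.4, 0.3, 0.3, 0.2, 0.2],
]
y_obs, mask = to_dense(steps_per_comp, values_per_comp, n_timesteps=10)
print(mask.sum(axis=0))  # [4 4 7]
```

The returned y_obs goes to obs=y_obs and mask goes to .mask(mask) in the sample statement.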

Thanks for your help, I will set them to zero and give it a try!