Filter result of ODE integration for inference


I have a model based on integrating a d-dimensional ordinary differential equation (ODE), where the dimension d is given by the parameter n_components.

sigma = numpyro.sample("sigma", dist.LogNormal(-1, 1).expand([n_components]))
y = odeint(dydt, ...)
numpyro.sample("y", dist.Normal(y, sigma), obs=y)

The result of the ODE integration (y) is an array of shape (n_timesteps x n_components). In the code above, the likelihood uses every entry of y. In my case I want to filter this further: depending on the component, a different subset of time steps should enter the inference.

For example, take a system with n_components = 3 and 10 time steps in the ODE integration, so y is a 10 x 3 array. In the inference I only want to include time steps [0, 1, 2, 5] for component 0, time steps [1, 2, 8, 9] for component 1, and time steps [0, 2, 3, 4, 5, 7, 8] for component 2.
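A minimal sketch, assuming NumPy and that the per-component index lists are known up front, of turning those lists into a boolean array with the same (n_timesteps, n_components) shape as y, where True marks an entry that should get an observation. The names observed_steps and obs_mask are hypothetical.

```python
import numpy as np

n_timesteps, n_components = 10, 3

# Time steps to include in the likelihood, one list per component
# (indices from the example above; components are numbered 0, 1, 2).
observed_steps = [
    [0, 1, 2, 5],            # component 0
    [1, 2, 8, 9],            # component 1
    [0, 2, 3, 4, 5, 7, 8],   # component 2
]

# Hypothetical boolean mask with the same shape as the ODE output y.
# True = this (time step, component) entry contributes an observation.
obs_mask = np.zeros((n_timesteps, n_components), dtype=bool)
for c, steps in enumerate(observed_steps):
    obs_mask[steps, c] = True

print(obs_mask.sum(axis=0))  # observed time steps per component: [4 4 7]
```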

What is the best way to do this?



What does “consider” mean? Does it mean include an observation for it, or something else?


It means including an observation for it. Of course, odeint always returns the complete matrix y (n_timesteps x n_components). The final y_obs could then have the form described above: instead of an array, it would be a list of lists, with one sublist per component.

I noticed that I have a typo in the code above. It should of course be:

numpyro.sample("y", dist.Normal(y, sigma), obs=y_obs)

Any idea how to set this up in the best way?

You should be able to use the distribution's mask method, something like:

numpyro.sample("y", dist.Normal(y, sigma).mask(my_mask), obs=y_obs)

Thanks, I will give it a try.

One question: what form does y_obs need to have then? Would it be an array of shape n_timesteps x n_components in which the masked elements are simply ignored and can be anything? Or would it be a list of lists (or 1-D arrays) containing only the valid observations?

It would be a full array; you need this for vectorization and speed.

The masked-out elements are arbitrary and will be ignored, although you probably want to make sure they're not NaN.

Thanks for your help, I will set them to zero and give it a try!
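If the observations start out as one list of values per component, a sketch of packing them into the full (n_timesteps, n_components) array with zeros in the unobserved slots, plus the matching boolean mask. The helper pack_observations is hypothetical, it assumes the time-step indices for each value list are known, and the numeric values in the usage example are dummy placeholders.

```python
import numpy as np


def pack_observations(steps_per_component, values_per_component, n_timesteps):
    """Pack ragged per-component observations into a dense array plus a mask.

    Hypothetical helper: steps_per_component[c] lists the time-step indices
    observed for component c, values_per_component[c] the matching values.
    Unobserved slots are set to 0.0 (any finite value would do).
    """
    n_components = len(steps_per_component)
    y_obs = np.zeros((n_timesteps, n_components))
    mask = np.zeros((n_timesteps, n_components), dtype=bool)
    for c, (steps, values) in enumerate(zip(steps_per_component, values_per_component)):
        if len(steps) != len(values):
            raise ValueError(f"component {c}: {len(steps)} steps but {len(values)} values")
        y_obs[steps, c] = values
        mask[steps, c] = True
    return y_obs, mask


# Usage with the index lists from the question and dummy values.
y_obs, mask = pack_observations(
    steps_per_component=[[0, 1, 2, 5], [1, 2, 8, 9], [0, 2, 3, 4, 5, 7, 8]],
    values_per_component=[
        [1.0, 0.9, 0.8, 0.5],
        [0.1, 0.2, 0.6, 0.7],
        [2.0, 1.9, 1.7, 1.6, 1.5, 1.2, 1.1],
    ],
    n_timesteps=10,
)
```

The resulting y_obs and mask can then be passed as obs=y_obs and .mask(mask) in the sample statement.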