Saving intermediate data for each posterior sample

Hello all,

I currently have a model that uses MCMC with NUTS to sample the parameters of a set of ODEs. I want to save the last timestep of the ODE solution's pytree for each sampled parameter set. Rather than regenerating the state from the posterior samples afterwards, I was wondering if I could save the pytree at the same time as I calculate incidence.

Here is some pseudocode as an example:

```python
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

# using diffrax to run the ODEs with the parameters sampled by NUTS
solution = odes.run(initial_state, sampled_params)
model_incidence = jnp.diff(solution.ys)
# saving the pytree in some external state
final_timestep.append(solution.ys[-1])
# likelihood step for MCMC
numpyro.sample(
    "incidence",
    dist.Poisson(model_incidence),
    obs=observed_values,
)
```

The issues I currently have with this are: A) everything is still a JAX tracer, and B) this method only saves the solution object a couple of times while MCMC compiles, and nothing during the actual sampling stage.

Am I going about this the wrong way? Can these issues be overcome, or are they inherent to the just-in-time compiled nature of the code?

Sorry for the oddly specified question, I feel I have taken a wrong turn somewhere.

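Your method doesn't work because JAX requires all jitted functions to be functionally pure. Appending to a globally scoped list from within the model is a side effect: under jit, the Python body of the model only runs while JAX traces it, so the list receives a few tracers during compilation and nothing while the compiled sampler runs. That explains both issues A and B. See "JAX - The Sharp Bits" in the JAX documentation for a more detailed explanation.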
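A minimal sketch of that tracing behaviour, using a toy function rather than the MCMC model (the names `saved` and `step` are hypothetical): the `append` runs once while JAX traces the function, and later calls execute the compiled version, which never touches the list.

```python
import jax
import jax.numpy as jnp

# Python-side list that the jitted function mutates (a side effect).
saved = []


@jax.jit
def step(x):
    y = x * 2.0
    saved.append(y)  # executes only while JAX traces `step`
    return y


# Five calls with the same input shape/dtype: traced once, then cached.
for i in range(5):
    step(jnp.float32(i))

print(len(saved))                # 1
print(type(saved[0]).__name__)   # a tracer type, not a concrete array
```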

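You should instead use `numpyro.deterministic("last_state", solution.ys[-1])`. The value is then recorded as part of each posterior draw, and you can retrieve the last state for every sample after running MCMC with `mcmc.get_samples()["last_state"]`.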
