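Your method doesn’t work because JAX requires jitted functions to be functionally pure. Appending to a globally scoped list from within the model is a side effect. When the model is jitted, its Python code runs only once, while JAX traces it, so the list ends up holding abstract tracers rather than the concrete value from each sample. See the JAX "Sharp Bits" documentation for a more detailed explanation.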
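The following minimal sketch illustrates the problem in plain JAX. The names `f` and `traced_values` are hypothetical. The append runs only while `f` is traced. A second call with an input of the same shape and dtype reuses the compiled function and never touches the list. The single stored entry is a tracer, not a number.

```python
import jax
import jax.numpy as jnp

traced_values = []  # global list mutated from inside a jitted function (a side effect)

@jax.jit
def f(x):
    y = x * 2.0
    traced_values.append(y)  # executes only during tracing, not on every call
    return y

f(jnp.array(1.0))
f(jnp.array(2.0))  # same shape/dtype: the cached compiled version runs, no re-trace

print(len(traced_values))  # only one entry, recorded at trace time
```

The pure alternative is to make the value part of the function's output, which is what a deterministic site does for a NumPyro model.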
You should instead record the value as a deterministic site with `numpyro.deterministic('last_state', solution.ys[-1])`. You can then retrieve the last state for each sample after running MCMC, for example from `mcmc.get_samples()['last_state']`.