How does the mutable_state value propagate to the outer SVIState?

When a Flax Linen BatchNorm layer is used in NumPyro, the mutable node has to carry the updated batch_stats from one iteration to the next. I have read the code, but I have trouble following how the value of this node propagates.

  1. In flax_module (contrib/module.py, lines 114-119):

        params = {"params": params, **nn_state}
        out, new_state = nn_module.apply(params, mutable=mutable, *args, **kwargs)
        new_state = jax.lax.stop_gradient(new_state)
        nn_state.update(**new_state)
        return out
    

    How does the last updated nn_state value propagate to the outer trace node, and from there into the SVIState container, so that it takes part in the next iteration? The outer trace records the message with Msg.copy(), so I would expect nn_state.update not to modify the corresponding value in the trace container.

  2. In loss_with_mutable_state() (infer/elbo.py, lines 161-256), I cannot find any explicit handling of the mutable node.
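The concern in item 1 can be checked with plain Python dictionaries. The sketch below assumes that a trace message is a `dict` whose "value" entry is the nn_state dictionary and that the trace records it with `msg.copy()`; `msg`, `trace` and the "net$state" site name are hypothetical stand-ins, not numpyro objects.

```python
# Hypothetical stand-ins: a message dict whose "value" is the mutable
# state dictionary, and a trace that records a copy of the message.
nn_state = {"batch_stats": {"mean": 0.0, "var": 1.0}}
msg = {"type": "mutable", "name": "net$state", "value": nn_state}

trace = {}
# dict.copy() is shallow: a new outer dict, but "value" is the same object.
trace[msg["name"]] = msg.copy()

# Later the module updates its state in place, mirroring
# nn_state.update(**new_state) in flax_module.
new_state = {"batch_stats": {"mean": 0.5, "var": 0.9}}
nn_state.update(**new_state)

recorded = trace["net$state"]["value"]
print(recorded is nn_state)     # the recorded message still points at nn_state
print(recorded["batch_stats"])  # so it sees the updated statistics
```

Only in-place mutation is visible this way: rebinding the name (`nn_state = new_state`) would leave the recorded message pointing at the old dictionary.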

Am I missing something?

The loss_fn executes flax_module, which retrieves the mutable node holding nn_state (a stateful dictionary), and nn_state.update(**new_state) updates the value of that mutable node in place. Msg.copy() only makes a shallow copy of the message, so the recorded message still refers to the same nn_state dictionary and sees the update. Then, in loss_with_mutable_state, we extract the mutable state here: numpyro/infer/elbo.py at commit bf44e0759f3f9b8f2908e17007cfcbf86c005df1 (pyro-ppl/numpyro on GitHub). The loss function returns the new state, so after each optimization step the new state is stored in SVIState.
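To make the flow concrete, here is a heavily simplified, pure-Python model of it. This is not numpyro's implementation: `model`, `loss_with_mutable_state`, `update`, the `SVIState` fields and the "net$state" site are hypothetical stand-ins that only mimic the steps described above. The module updates its state dictionary in place, the loss collects the values of all mutable sites from the trace and returns them, and each step stores that state in a new SVIState, which this sketch feeds back in on the next step.

```python
from typing import NamedTuple


class SVIState(NamedTuple):
    # Hypothetical container: current parameters plus mutable state.
    params: dict
    mutable_state: dict


def model(trace, params, mutable_state, x):
    name = "net$state"
    # Seed the mutable site from the state carried over in SVIState
    # (copied so the previous SVIState is left untouched).
    nn_state = dict(mutable_state.get(name, {"running_mean": 0.0}))
    msg = {"type": "mutable", "name": name, "value": nn_state}
    trace[name] = msg.copy()  # shallow copy, as in the question

    # Stand-in for nn_module.apply(..., mutable=...): compute new
    # running statistics, then update the state dict in place.
    momentum = 0.9
    new_state = {"running_mean": momentum * nn_state["running_mean"] + (1 - momentum) * x}
    nn_state.update(**new_state)
    return (params["w"] * x - 1.0) ** 2  # toy squared-error loss


def loss_with_mutable_state(params, mutable_state, x):
    trace = {}
    loss = model(trace, params, mutable_state, x)
    # Collect the value of every mutable site recorded in the trace.
    new_mutable = {
        name: site["value"] for name, site in trace.items() if site["type"] == "mutable"
    }
    return {"loss": loss, "mutable_state": new_mutable}


def update(svi_state, x, lr=0.1):
    out = loss_with_mutable_state(svi_state.params, svi_state.mutable_state, x)
    # Toy gradient step: d/dw (w*x - 1)^2 = 2*x*(w*x - 1).
    w = svi_state.params["w"]
    new_params = {"w": w - lr * 2 * x * (w * x - 1.0)}
    # The new mutable state is stored in the next SVIState.
    return SVIState(new_params, out["mutable_state"]), out["loss"]


state = SVIState(params={"w": 0.0}, mutable_state={})
for x in [1.0, 2.0, 3.0]:
    state, loss = update(state, x)
print(state.mutable_state)
```

From the outside, the state enters each step as an input and leaves as part of the return value, which fits JAX's functional style; the in-place update only happens inside a single call.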

Thanks for the detailed explanation!