If we use the BatchNorm layer of Flax Linen in NumPyro, the `mutable` node needs to carry the updated `batch_stats` from one iteration to the next. However, after reading the code I still have trouble understanding how this node's value is propagated.
- In `flax_module` (contrib/module.py, lines 114-119):

  ```python
  params = {"params": params, **nn_state}
  out, new_state = nn_module.apply(params, mutable=mutable, *args, **kwargs)
  new_state = jax.lax.stop_gradient(new_state)
  nn_state.update(**new_state)
  return out
  ```

  How does the last updated `nn_state` value propagate to the outer trace node, and from there into the `SVIState` container, so that it takes part in the next iteration? The outer trace uses `Msg.copy()` to record the site, so it seems `nn_state.update` can't modify the corresponding value in the trace container.
- In `loss_with_mutable_state()` (infer/elbo.py, lines 161-256): I can't find any explicit step that processes the `mutable` node.
Am I missing something?