I’d like my model function to have side effects, but JAX won’t allow this in a straightforward way because of JIT compilation. When my model is wrapped in numpyro’s SVI machinery, it’s unclear to me how I could return additional outputs so that my side effects happen in a purely functional way. Is there, for example, a way to mark certain parameters as untrainable by SVI, and somehow set those parameters inside my model function without angering JAX? Or is there a way to access the return values of a model at each SVI update step, so that I can pass those values back in as arguments during the next step?
I believe I may have answered my own question. For the benefit of others reading this, and for my own future reference, the key component seems to be numpyro’s `mutable` function, which is not currently listed in the numpyro primitives documentation, but you can find its definition here.
Yeah, I think it is best to use `mutable` here. We only use it for neural networks with mutable states, like the running mean of BatchNorm layers. What’s your use case? We can expose it if it is useful.