Hi, I am looking to integrate an equinox neural network (which is just a pytree) into numpyro. I would like the network weights to be trained and adjusted as part of a broader probabilistic model. There are a few posts and comments on equinox and numpyro on the forum, but I wanted to double-check a few things.
I stumbled upon the issue comment here and compared it to the documentation:
I just want to confirm that numpyro.param is a constant in an MCMC inference context but an optimizable parameter in an SVI context?
For SVI, then, if I do not want priors on the network weights, I can include the network parameters using numpyro.param and they will be optimized as expected by something like optax.adam?
And in an MCMC context I would have to put priors on the weights by wrapping/replacing the weights in the equinox pytree with the output of numpyro.sample, correct? If so, would the weights then be trained/adjusted too?