MCMC, SVI optimization and pytrees, e.g. equinox

Hi, I am looking to integrate an equinox neural network (which is just a pytree) into numpyro. I would like the network weights to be trained and adjusted in the context of a broader probabilistic model. There are a few posts and comments on equinox and numpyro on the forum, but I wanted to double-check a few things.

I stumbled upon the issue comment here and compared it to the documentation:

I just want to confirm: is numpyro.param a constant in an MCMC inference context, but an optimizable parameter in an SVI context?

For SVI, then, if I do not want priors on the network weights, can I include the network parameters using numpyro.param and have them optimized as expected with something like optax.adam?

And in an MCMC context, I would have to put priors on the weights by wrapping/replacing the weights in the equinox pytree with the output of numpyro.sample, correct? If so, would the weights then be trained/adjusted too?

Yes to all of your questions. For an example, see random_flax_module.


Thank you very much, appreciate it @fehiepsi

Thanks @fehiepsi