NumPyro + Flax?

I think the easiest way is to write a potential_fn whose inputs are your “required-to-sample” parameters. If you want to put a Normal() prior on a parameter “weight”, simply add Normal().log_prob(weight) to the joint log density (the potential is the negative of that log density). You can also write helpers to make this job easier.

Currently, NumPyro supports optimizing the parameters of Stax modules with SVI. To do inference on the parameters of nn modules with MCMC, it is better to use a potential_fn as above (this applies to Stax, Flax, Haiku,…).