Hi. I’m attempting to build a NumPyro-based NUTS inference tool for cosmological tasks. The model set-up is as follows:
- a number of parameters are passed in as an array; the number of parameters and their names vary with the datasets used, so the input has to be generic
- each parameter comes with a prior
- an external JAX-differentiable log-likelihood function is provided
Currently I’m just trying to implement the most bare-bones model, basically a recreation of the last part of this example but in NumPyro instead of BlackJAX.
I am confused about the syntax for defining a model in NumPyro, specifically the numpyro.sample declarations. Every example I’ve seen defines each parameter separately, but that doesn’t work for me, since the number and names of the model parameters are not known in advance. Is there a way to do this “in bulk”?