NUTS sampling with an array of parameters

Hi. I’m attempting to build a NumPyro-based NUTS inference tool for cosmological tasks. The model set-up is as follows:

  • the parameters are passed in as an array; their number and names vary with the datasets used, so the input handling has to be generic
  • each parameter comes with a prior
  • an external JAX-differentiable log-likelihood function is provided

Currently I’m just trying to implement the most bare-bones model: essentially a recreation of the last part of this example, but in NumPyro instead of BlackJAX.
I am confused about the syntax for defining a model in NumPyro, specifically the numpyro.sample declarations. Every example I’ve seen defines each parameter separately, which doesn’t work for me because the number and names of the model parameters are not known in advance. Is there a way to do this “in bulk”?

You can define the model by looping over a dictionary that maps each parameter name to its prior distribution:

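import jax.numpy as jnp
import numpyro

def model(param_priors):
    # one sample site per (name, prior) pair; dict order fixes the array order
    params = [numpyro.sample(name, prior) for name, prior in param_priors.items()]
    return jnp.stack(params)  # 1-D array of all parameters, e.g. for the likelihood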