Handling external log-likelihood functions for NUTS

Hi! I am trying to run NUTS with an external log-likelihood function, implemented in the model as a numpyro.factor, but I am somewhat confused about what the input and output of this function should be. Basically, I have a jitted function that takes in a dictionary of {“parameter name”: parameter value} and outputs the corresponding log-likelihood value. How do I wrap it so that it takes in the values returned by my numpyro.sample calls?

In more detail:
What I have is

like = jax.jit(candl.tools.get_params_to_logl_func(candl_like, pars_to_theory_specs))
like({'H0': 67.37, 'ombh2': 0.02233, 'omch2': 0.1198, 'logA': 3.043, 'ns': 0.9652, 'tau': 0.054, 'yp': 1.0}) #an example, this returns Array(-153.57355665, dtype=float64)

and the model is something like

def CModel(fiducial):
    cosmo_param = []
    for name, point in fiducial.items():
        cosmo_param.append(numpyro.sample(name, dist.Normal(1.01*point, 0.1*point))) #placeholder prior, e.g. fiducial = {'H0': 67.37, 'ombh2': 0.02233, 'omch2': 0.1198, 'logA': 3.043, 'ns': 0.9652, 'tau': 0.054, 'yp': 1.0}
    numpyro.factor("log_lkl", like_transformed(cosmo_param))

What I’m unsure about is what like_transformed needs to look like.
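One way to write like_transformed is to zip the sampled values back up with their parameter names, so the jitted likelihood receives the same dictionary it was built for. This is a minimal sketch, assuming the order of the sampled values matches the insertion order of the fiducial dictionary. `toy_like` is a hypothetical Gaussian stand-in for the candl likelihood, used only so the wrapping can be shown without candl:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # the candl example above returns float64

fiducial = {'H0': 67.37, 'ombh2': 0.02233, 'omch2': 0.1198,
            'logA': 3.043, 'ns': 0.9652, 'tau': 0.054, 'yp': 1.0}
param_names = list(fiducial.keys())

@jax.jit
def toy_like(params):
    # Hypothetical stand-in: takes {name: value} and returns a scalar log-likelihood
    # (independent Gaussians centred on the fiducial values with 1% widths).
    return sum(-0.5 * ((params[k] - fiducial[k]) / (0.01 * fiducial[k])) ** 2
               for k in param_names)

def like_transformed(values, like=toy_like, names=param_names):
    # Turn the list of sampled values into a single jax.numpy array,
    # then rebuild the {name: value} dict the jitted likelihood expects.
    values = jnp.asarray(values)
    return like(dict(zip(names, values)))
```

Passing the fiducial values in order, `like_transformed([fiducial[k] for k in param_names])` returns a 0-d array; with the real likelihood, `like` would be the jitted candl function instead of `toy_like`.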

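It should return a float, i.e. the scalar log-likelihood value that numpyro.factor adds to the model's log density.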

OK, it turns out the issue was just that the input values were not a jax.numpy array; adding that conversion step to the function fixed it.