Hi! I am trying to run NUTS with an external log-likelihood function, implemented in the model as a numpyro.factor, but I am somewhat confused about what the input and output of this function should be. Basically, I have a jitted function that takes in a dictionary of {“parameter name”: parameter value} and outputs the corresponding log-likelihood value. How do I wrap it so that it takes in the values returned by my numpyro.sample calls?
In more detail:
What I have is:
import jax
import candl.tools
like = jax.jit(candl.tools.get_params_to_logl_func(candl_like, pars_to_theory_specs))
like({'H0': 67.37, 'ombh2': 0.02233, 'omch2': 0.1198, 'logA': 3.043, 'ns': 0.9652, 'tau': 0.054, 'yp': 1.0})  # example call; returns Array(-153.57355665, dtype=float64)
and the model is something like:
import numpyro
import numpyro.distributions as dist

def CModel(fiducial):
    # e.g. fiducial = {'H0': 67.37, 'ombh2': 0.02233, 'omch2': 0.1198, 'logA': 3.043, 'ns': 0.9652, 'tau': 0.054, 'yp': 1.0}
    cosmo_param = []
    for name, point in fiducial.items():
        cosmo_param.append(numpyro.sample(name, dist.Normal(1.01 * point, 0.1 * point)))  # placeholder prior
    numpyro.factor("log_lkl", like_transformed(cosmo_param))
What I’m unsure about is what like_transformed needs to look like.