Log probability and gradient of a Pyro model

I have two implementations of a numerical Bayesian inference algorithm, one written in NumPy and one in PyTorch, and would like to interface them with NumPyro and Pyro, respectively. With NumPyro, I have used the log_density function from numpyro.infer.util to calculate the log probability of a model, and JAX's autodiff to calculate the gradient of that log probability.

I would now like to do the same in Pyro but am unsure where to begin. I have not found an equivalent of log_density in Pyro, and I am unsure whether I am looking in the wrong place or it is simply not implemented in Pyro.

Any pointers are appreciated!

In Pyro you can compute the log density of a model with poutine.trace:

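import pyro.poutine as poutine
# Record one run of the model; unobserved sites are sampled from their priors.
trace = poutine.trace(model).get_trace(*model_args, **model_kwargs)
log_density = trace.log_prob_sum()  # sum of log p over all sample sites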