How is the gradient of distributions calculated?

Hi! In my understanding, HMC requires access to the gradient of the probability density. How does Pyro, for example, evaluate the gradient of the pdf of a Beta distribution? How can I expose gradient information for custom distributions I have written myself (especially in NumPyro)?

Pyro and NumPyro make use of automatic differentiation in their underlying tensor frameworks, PyTorch and JAX respectively. If you write your log-density using differentiable PyTorch or JAX primitives, the gradient is defined automatically.
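As a minimal sketch on the JAX/NumPyro side (this is not NumPyro's internal code, just an illustration): a Beta log-density written with `jax.numpy` operations can be differentiated with `jax.grad` with no extra work.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import gammaln

def beta_logpdf(x, a, b):
    # log p(x | a, b) for a Beta(a, b) distribution, written with JAX primitives
    log_norm = gammaln(a) + gammaln(b) - gammaln(a + b)
    return (a - 1.0) * jnp.log(x) + (b - 1.0) * jnp.log1p(-x) - log_norm

# The gradient with respect to x (what HMC needs) comes from autodiff directly.
grad_logpdf = jax.grad(beta_logpdf, argnums=0)
print(grad_logpdf(0.3, 2.0, 5.0))
```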

In the unlikely event that your function cannot be expressed using built-in primitive operations with predefined gradients, both frameworks allow you to implement a custom operation with a custom gradient function; see this PyTorch tutorial or this JAX tutorial for instructions.
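For the JAX side, a minimal sketch of that mechanism using `jax.custom_jvp` (the function and derivative here are illustrative assumptions, not taken from the tutorials above; PyTorch's analogue is `torch.autograd.Function`):

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def log1pexp(x):
    # Numerically stable log(1 + exp(x))
    return jnp.logaddexp(0.0, x)

@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
    # Hand-written derivative: d/dx log(1 + exp(x)) = sigmoid(x)
    (x,), (x_dot,) = primals, tangents
    return log1pexp(x), jax.nn.sigmoid(x) * x_dot

print(jax.grad(log1pexp)(2.0))  # sigmoid(2.0) ≈ 0.8808
```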
