Defining an arbitrary (unnormalized) density function and sampling from it?

The HMC/NUTS samplers accept a potential_fn argument (the negative log of the unnormalized density), so you should be able to use that directly, e.g.

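import torch
from pyro.infer import MCMC, NUTS


def unnormalized_pdf(x):
    # x is a dict of parameter tensors, keyed the same way as initial_params
    return torch.exp(-x['u']**2)


# potential_fn returns the negative log of the (unnormalized) density
nuts = NUTS(potential_fn=lambda x: -torch.log(unnormalized_pdf(x)))
mcmc = MCMC(nuts, num_samples=100, initial_params={'u': torch.tensor(0.)})
mcmc.run()
print(mcmc.get_samples())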
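
One thing that may be worth watching: computing the density first and taking the log afterwards can underflow far from the mode, so it is often safer to write the potential directly in log space. Below is a minimal standard-library sketch of that point for the same density exp(-u^2); the function names potential_via_density and potential_direct are hypothetical and not part of Pyro.

```python
import math


def potential_via_density(u):
    """Potential computed as -log of the unnormalized density exp(-u^2)."""
    density = math.exp(-u ** 2)  # underflows to 0.0 when |u| is large
    # log(0) is undefined, so report an infinite potential in that case
    return -math.log(density) if density > 0.0 else math.inf


def potential_direct(u):
    """Same potential written in log space: -log(exp(-u^2)) = u^2."""
    return u ** 2


# Compare the two formulations near the mode and far from it
for u in (0.5, 3.0, 40.0):
    print(u, potential_via_density(u), potential_direct(u))
```

The two versions agree near the mode, but far out in the tails only the log-space version stays finite. As a sanity check on the sampler output, exp(-u^2) is an unnormalized Gaussian with mean 0 and variance 1/2, so the sample variance of 'u' should land near 0.5 once enough samples are drawn.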