I understand that Pyro's NUTS uses an adaptation scheme during warmup to tune the mass matrix automatically, based on an estimate of the target distribution's covariance. However, for my specific problem I have prior knowledge that I would like to incorporate by using a fixed mass matrix.
I would greatly appreciate guidance on how to configure (or modify) the NUTS sampler in Pyro to use a fixed mass matrix. I want to specify the mass matrix manually rather than rely on the adaptive scheme.
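For context on what "fixed" means here: in HMC and NUTS the mass matrix `M` sets the momentum distribution `p ~ N(0, M)` and the kinetic energy `0.5 * p^T M^{-1} p`, and warmup adaptation normally re-estimates `M^{-1}` from draws. The pure-Python sketch below is not Pyro's implementation; it holds a user-chosen `INV_MASS` constant across leapfrog steps on a hypothetical 2-D Gaussian target (`PREC`, the start point, step size and seed are all made up for illustration) and checks that the energy error stays small.

```python
import math
import random


def matvec(m, v):
    """Multiply a square matrix (list of rows) by a vector."""
    return [sum(mij * vj for mij, vj in zip(row, v)) for row in m]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def cholesky(a):
    """Lower-triangular L with L L^T = a (a must be symmetric positive definite)."""
    n = len(a)
    l = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = sum(l[i][k] * l[j][k] for k in range(j))
            l[i][j] = math.sqrt(a[i][i] - s) if i == j else (a[i][j] - s) / l[j][j]
    return l


def leapfrog(x, p, grad_u, inv_mass, step_size):
    """One leapfrog step for H(x, p) = U(x) + 0.5 * p^T M^{-1} p with M^{-1} held fixed."""
    p = [pi - 0.5 * step_size * gi for pi, gi in zip(p, grad_u(x))]
    v = matvec(inv_mass, p)  # velocity dx/dt = M^{-1} p
    x = [xi + step_size * vi for xi, vi in zip(x, v)]
    p = [pi - 0.5 * step_size * gi for pi, gi in zip(p, grad_u(x))]
    return x, p


def hamiltonian(x, p, u, inv_mass):
    return u(x) + 0.5 * dot(p, matvec(inv_mass, p))


# Hypothetical target: zero-mean Gaussian with precision PREC, i.e. U(x) = 0.5 x^T PREC x.
PREC = [[2.0, -1.0], [-1.0, 2.0]]


def u(x):
    return 0.5 * dot(x, matvec(PREC, x))


def grad_u(x):
    return matvec(PREC, x)


# Fixed inverse mass matrix from "prior knowledge": here the exact target
# covariance (the inverse of PREC), so the mass matrix M itself equals PREC.
INV_MASS = [[2.0 / 3.0, 1.0 / 3.0], [1.0 / 3.0, 2.0 / 3.0]]
MASS = PREC

rng = random.Random(0)
p0 = matvec(cholesky(MASS), [rng.gauss(0.0, 1.0) for _ in range(2)])  # p ~ N(0, M)
x0 = [1.0, -0.5]

h0 = hamiltonian(x0, p0, u, INV_MASS)
x, p = x0, p0
for _ in range(20):
    x, p = leapfrog(x, p, grad_u, INV_MASS, step_size=0.1)
energy_error = abs(hamiltonian(x, p, u, INV_MASS) - h0)
print(f"|H_end - H_start| = {energy_error:.2e}")
```

With `INV_MASS` equal to the target covariance, every direction behaves like a unit-frequency oscillator, so one step size suits all of them; a poorer fixed choice would typically force a smaller step size.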
Any suggestions, code examples, or insights would be immensely helpful. The sample code below is only for illustration; my actual problem does not have such a simple form.
```python
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import MCMC, NUTS


def model(data):
    loc = torch.zeros(2)   # mean of the multivariate normal prior
    scale = torch.ones(2)  # diagonal of the Cholesky factor (per-dimension std. dev.)
    with pyro.plate("data", len(data)):
        # MultivariateNormal already has event_shape (2,), so no .to_event(1) is needed
        latent_variable = pyro.sample(
            "latent_variable",
            dist.MultivariateNormal(loc, scale_tril=torch.diag(scale)),
        )
        pyro.sample(
            "observed_data",
            dist.MultivariateNormal(latent_variable, covariance_matrix=torch.eye(2)),
            obs=data,
        )


observed_data = torch.tensor([[1.0, 2.0], [-1.0, -2.0]])
nuts_kernel = NUTS(model)  # by default the mass matrix is adapted during warmup
mcmc = MCMC(nuts_kernel, num_samples=1000, warmup_steps=500)
mcmc.run(observed_data)
posterior_samples = mcmc.get_samples()
```
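One way to emulate a fixed mass matrix with a sampler whose own mass matrix is the identity (for instance Pyro's `NUTS` with `adapt_mass_matrix=False`, which I believe then keeps its initial identity matrix) is a linear reparameterization: if `M^{-1} = L L^T`, running identity-mass dynamics on `z` with `x = L z` and momentum `q = L^T p` is algebraically the same as running the dynamics on `x` with inverse mass matrix `M^{-1}`. The pure-Python sketch below checks this numerically; the potential, `INV_MASS`, starting values and step size are hypothetical, and none of it is Pyro code.

```python
import math


def matvec(m, v):
    """Multiply a square matrix (list of rows) by a vector."""
    return [sum(mij * vj for mij, vj in zip(row, v)) for row in m]


def transpose(m):
    return [list(col) for col in zip(*m)]


def cholesky(a):
    """Lower-triangular L with L L^T = a (a must be symmetric positive definite)."""
    n = len(a)
    l = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = sum(l[i][k] * l[j][k] for k in range(j))
            l[i][j] = math.sqrt(a[i][i] - s) if i == j else (a[i][j] - s) / l[j][j]
    return l


def solve_lower(l, b):
    """Solve L y = b for lower-triangular L by forward substitution."""
    y = []
    for i, row in enumerate(l):
        y.append((b[i] - sum(row[k] * y[k] for k in range(i))) / row[i])
    return y


def leapfrog(x, p, grad_u, inv_mass, step_size):
    """One leapfrog step for H(x, p) = U(x) + 0.5 * p^T M^{-1} p."""
    p = [pi - 0.5 * step_size * gi for pi, gi in zip(p, grad_u(x))]
    x = [xi + step_size * vi for xi, vi in zip(x, matvec(inv_mass, p))]
    p = [pi - 0.5 * step_size * gi for pi, gi in zip(p, grad_u(x))]
    return x, p


# Hypothetical non-Gaussian potential U(x) = 0.5 x^T P x + 0.1 * sum(x_i^4).
P = [[2.0, -1.0], [-1.0, 2.0]]


def grad_u(x):
    return [g + 0.4 * xi ** 3 for g, xi in zip(matvec(P, x), x)]


# Hypothetical fixed inverse mass matrix (e.g. a prior guess of the posterior
# covariance) and its Cholesky factor L, so that INV_MASS = L L^T.
INV_MASS = [[0.7, 0.3], [0.3, 0.7]]
L = cholesky(INV_MASS)
LT = transpose(L)
IDENTITY = [[1.0, 0.0], [0.0, 1.0]]


def grad_u_whitened(z):
    # U_w(z) = U(L z), so grad U_w(z) = L^T grad U(L z) by the chain rule.
    return matvec(LT, grad_u(matvec(L, z)))


x, p = [1.0, -0.5], [0.3, 0.8]
z, q = solve_lower(L, x), matvec(LT, p)  # z = L^{-1} x, q = L^T p

for _ in range(25):
    x, p = leapfrog(x, p, grad_u, INV_MASS, 0.05)           # fixed M^{-1} on x
    z, q = leapfrog(z, q, grad_u_whitened, IDENTITY, 0.05)  # identity mass on z

gap = max(abs(a - b) for a, b in zip(x, matvec(L, z)))
print(f"max |x - L z| after 25 steps: {gap:.1e}")
```

The same map sends `p ~ N(0, M)` to `q ~ N(0, I)`, so momentum resampling agrees as well, and because `x = L z` has a constant Jacobian the target density in `z` changes only by a constant factor. Turning this into a Pyro model (sampling a whitened site and building `latent_variable` from it, with the prior transformed accordingly) is an assumption that should be checked against the Pyro documentation.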