Predefined fixed mass matrix for NUTS

I understand that NUTS uses an adaptation scheme during warmup to adjust the mass matrix automatically, based on an estimate of the target distribution's covariance. However, for my specific problem I have prior knowledge that I would like to incorporate by using a fixed mass matrix.
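
For reference, the standard HMC setup that NUTS builds on shows where the mass matrix $M$ enters. This uses generic notation and is not taken from Pyro's source. With potential energy $U(q) = -\log p(q)$ in the unconstrained space, the sampler simulates

$$
H(q, p) = U(q) + \tfrac{1}{2}\, p^\top M^{-1} p, \qquad p \sim \mathcal{N}(0, M), \qquad M^{-1} \approx \operatorname{Cov}[q]
$$

Warmup adaptation estimates $M^{-1}$ from the warmup draws (diagonal by default, dense with `full_mass=True` in Pyro). A fixed mass matrix means supplying $M^{-1}$ directly instead. Equivalently, writing $M^{-1} = LL^\top$ and substituting $q = Lz$, the same dynamics expressed in $z$ have identity mass:

$$
\frac{dz}{dt} = \tilde p, \qquad \frac{d\tilde p}{dt} = -\nabla_z U(Lz), \qquad \tilde p = L^\top p \sim \mathcal{N}(0, I)
$$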

I would greatly appreciate it if someone could guide me on how to modify the NUTS sampler in Pyro to include a fixed mass matrix. I want to specify the mass matrix manually, rather than relying on the adaptive scheme.

Any suggestions, code examples, or insights you can provide would be immensely helpful. I have attached some sample code just for illustration; my actual problem does not have such a simple form.

```python
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import MCMC, NUTS


def model(data):
    loc = torch.zeros(2)   # mean of the multivariate normal prior
    scale = torch.ones(2)  # per-dimension scales; scale_tril below is diag(scale)
    with pyro.plate("data", len(data)):
        # MultivariateNormal already has event_shape (2,), so no .to_event(1) is needed
        latent_variable = pyro.sample(
            "latent_variable", dist.MultivariateNormal(loc, scale_tril=torch.diag(scale))
        )
        pyro.sample(
            "observed_data", dist.MultivariateNormal(latent_variable, torch.eye(2)), obs=data
        )


observed_data = torch.tensor([[1.0, 2.0], [-1.0, -2.0]])
nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_samples=1000, warmup_steps=500)
mcmc.run(observed_data)  # the chain must be run before samples can be retrieved
posterior_samples = mcmc.get_samples()
```

Hi, I’m coming at this from NumPyro experience rather than Pyro, but I understand them to be similar in the broad sense.

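In NumPyro, you can pre-specify the (inverse) mass matrix when defining the NUTS kernel that goes into the MCMC sampler object, e.g.:

```python
kernel = NUTS(model, inverse_mass_matrix=[SOME MATRIX])
```

By default this is only the starting value and is still adapted during warmup; passing `adapt_mass_matrix=False` keeps it fixed.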

I’m not sure if base Pyro has the same utility.