Regarding using a custom distribution with HMC

Dear developers and other Pyro experts,

I have been trying to set up an HMC/MCMC/NUTS sampling routine with Pyro, specifically by providing my own likelihood function that ‘describes’ the posterior distribution but is not constructed from a combination of torch distributions. So far I have not been able to solve my issues, and there are few examples that come close to my use case (which may mean I am either overcomplicating things or missing the key idea of how to approach it entirely).

My questions then are:

  • How to construct a custom distribution such that Pyro knows to use it as the distribution to be sampled.

    • Which class methods are necessary to do this

    • How to set up the model/distribution instance properly

For context, I am currently trying to sample a 5-dimensional model fitted to observations (its likelihood is described by a function posterior_function that compares the data with the results of the model). My ultimate goal is to use NeuTra NUTS sampling and modifications thereof on my model. This may introduce other issues later down the line, but currently I am stuck on the NUTS sampling.

My code setup is as follows:

The distribution class:


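import numpy as np
import torch
import pyro
import pyro.distributions as dist
from pyro.distributions import constraints


class BayesFitClass(pyro.distributions.Distribution):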
    """
    BayesFit class
    """

    support = constraints.real

    def __init__(self, posterior_function, data):
        """
        Init function to set up the BayesFit model.

        Input:
            posterior_function: function comparing the model for a given param_dict with the data
            data: the data passed on to posterior_function
        """

        self.posterior_function = posterior_function
        self.data = data
        super(pyro.distributions.Distribution, self).__init__()

    def sample(self, sample_shape=torch.Size()):
        """
        Sample method for the BayesFit posterior
        """

        # Parameters
        slopein                 = pyro.sample('slopein', dist.Uniform(0 + 1e-4, 3))
        log10_slopeout          = pyro.sample('log10_slopeout', dist.Uniform(np.log10(3), np.log10(4e2)))
        log10_J0                = pyro.sample('log10_J0', dist.Uniform(0, 4))
        log10_rho0              = pyro.sample('log10_rho0', dist.Uniform(3, 9))
        log10_Rs                = pyro.sample('log10_Rs', dist.Uniform(-4, 2))

        return torch.Tensor([slopein, log10_slopeout, log10_J0, log10_rho0, log10_Rs])

    def log_prob(self, sample_tensor):
        """
        Function to get the log-probability from the posterior
        """

        # Put in param dict shape
        param_dict = {
            'slopein':          sample_tensor[0],
            'log10_slopeout':   sample_tensor[1],
            'log10_J0':         sample_tensor[2],
            'log10_rho0':       sample_tensor[3],
            'log10_Rs':         sample_tensor[4],
        }

        return self.posterior_function(
            param_dict=param_dict,
            data=self.data
        )

    def batch_shape():
        """
        TODO: add batch shape property
        """

    def event_shape():
        """
        TODO: add event shape property
        """
        return 5

Some distribution and model instance creation functions:


def return_BayesFitClass_instance(**bayesfit_class_init_args):
    """
    Function to return BayesFitClass instance
    """

    BayesFitClass_instance = BayesFitClass(**bayesfit_class_init_args)

    return BayesFitClass_instance

def new_BayesFitClass_model(**bayesfit_class_init_args):
    """
    Wrapper Function to return a model function
    """

    def model():

        return pyro.sample('model', return_BayesFitClass_instance(**bayesfit_class_init_args))

    return model

def return_BayesFitClass_model(**bayesfit_class_init_args):
    """
    Function that returns a model method for the sampling of the BayesFitClass
    """

    #
    return new_BayesFitClass_model(**bayesfit_class_init_args)

Running NUTS sampling with this model currently leads to the following error:


[generate_samples.py:143 -     generate_samples ] 2022-07-26 23:18:57,580: Running NUTS sampler directly on target distribution

Warmup:   0%|          | 0/150 [00:00, ?it/s]Called bayesfit_posterior_function:
    with parameters: {'slopein': tensor(-1.4403, grad_fn=<SelectBackward0>), 'log10_slopeout': tensor(1.1604, grad_fn=<SelectBackward0>), 'log10_J0': tensor(0.9963, grad_fn=<SelectBackward0>), 'log10_rho0': tensor(-0.4383, grad_fn=<SelectBackward0>), 'log10_Rs': tensor(-0.3154, grad_fn=<SelectBackward0>)}
    norm: 0.00011603959196843638
    evaluated_df: [8.60935879e-22 2.51503576e-32 8.86303219e-32 ... 2.89534001e-12
 3.89460097e-33 5.07838347e-36]
    any((evaluated_df/norm)==0): False
Traceback (most recent call last):
  File "/home/david/projects/hmc_project/repo/hmc_project_code/projects/usecase_projects/project_paula_gherghinescu/main.py", line 66, in <module>
    generate_samples(
  File "/home/david/projects/hmc_project/repo/hmc_project_code/functions/sampling/generate_samples.py", line 147, in generate_samples
    nuts_on_target_results = sampler_wrapper(
  File "/home/david/projects/hmc_project/repo/hmc_project_code/functions/sampling/generate_samples.py", line 39, in sampler_wrapper
    sampler_results = sampler_function(
  File "/home/david/projects/hmc_project/repo/hmc_project_code/functions/sampling/nuts_aahmc_sampler.py", line 53, in nuts_aahmc_sampler
    mcmc.run()
  File "/home/david/.pyenv/versions/hmc3.9.9/lib/python3.9/site-packages/pyro/poutine/messenger.py", line 12, in _context_wrap
    return fn(*args, **kwargs)
  File "/home/david/.pyenv/versions/hmc3.9.9/lib/python3.9/site-packages/pyro/infer/mcmc/api.py", line 563, in run
    for x, chain_id in self.sampler.run(*args, **kwargs):
  File "/home/david/.pyenv/versions/hmc3.9.9/lib/python3.9/site-packages/pyro/infer/mcmc/api.py", line 223, in run
    for sample in _gen_samples(
  File "/home/david/.pyenv/versions/hmc3.9.9/lib/python3.9/site-packages/pyro/infer/mcmc/api.py", line 144, in _gen_samples
    kernel.setup(warmup_steps, *args, **kwargs)
  File "/home/david/.pyenv/versions/hmc3.9.9/lib/python3.9/site-packages/pyro/infer/mcmc/hmc.py", line 325, in setup
    self._initialize_model_properties(args, kwargs)
  File "/home/david/.pyenv/versions/hmc3.9.9/lib/python3.9/site-packages/pyro/infer/mcmc/hmc.py", line 259, in _initialize_model_properties
    init_params, potential_fn, transforms, trace = initialize_model(
  File "/home/david/.pyenv/versions/hmc3.9.9/lib/python3.9/site-packages/pyro/infer/mcmc/util.py", line 468, in initialize_model
    initial_params = _find_valid_initial_params(
  File "/home/david/.pyenv/versions/hmc3.9.9/lib/python3.9/site-packages/pyro/infer/mcmc/util.py", line 351, in _find_valid_initial_params
    pe_grad, pe = potential_grad(potential_fn, params)
  File "/home/david/.pyenv/versions/hmc3.9.9/lib/python3.9/site-packages/pyro/ops/integrator.py", line 76, in potential_grad
    potential_energy = potential_fn(z)
  File "/home/david/.pyenv/versions/hmc3.9.9/lib/python3.9/site-packages/pyro/infer/mcmc/util.py", line 281, in _potential_fn
    log_joint = self.trace_prob_evaluator.log_prob(model_trace)
  File "/home/david/.pyenv/versions/hmc3.9.9/lib/python3.9/site-packages/pyro/infer/mcmc/util.py", line 238, in log_prob
    return model_trace.log_prob_sum()
  File "/home/david/.pyenv/versions/hmc3.9.9/lib/python3.9/site-packages/pyro/poutine/trace_struct.py", line 207, in log_prob_sum
    log_p = scale_and_mask(log_p, site["scale"], site["mask"]).sum()
AttributeError: 'float' object has no attribute 'sum'

I think this is caused by some of the class methods not returning the correct type of output, which points to my lack of understanding of what they should return.
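A small sketch of why this particular error appears, assuming (as the printed NumPy `evaluated_df` arrays suggest) that `posterior_function` returns a plain Python float: Pyro's `log_prob_sum` calls `.sum()` on the (scaled and masked) value returned by `log_prob`, and a float has no such method. Wrapping the float in `torch.tensor(...)` would remove the crash but not the underlying problem, because HMC/NUTS also needs gradients of the log-density, so the computation itself has to stay in PyTorch. Both helper functions below are hypothetical.

```python
import torch


def log_density_float(x):
    # Hypothetical: returns a plain Python float, e.g. after NumPy code.
    return float(-0.5 * (x ** 2).sum())


def log_density_torch(x):
    # Hypothetical: the same density, kept as a differentiable torch tensor.
    return -0.5 * (x ** 2).sum()


x = torch.tensor([0.5, -1.0], requires_grad=True)

lp_float = log_density_float(x)
print(hasattr(lp_float, "sum"))  # False: the same failure as in log_prob_sum

lp = log_density_torch(x)
lp.backward()  # gradients flow back to x, which HMC needs
print(x.grad)  # d/dx of -0.5 * x**2 is -x
```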

Current versions of the packages are:


pyro-api==0.1.2
pyro-ppl==1.8.0
torch==1.10.2

I would be very grateful if someone with more experience could point me towards a solution or a write-up of how to approach this type of sampling. Most of the tutorials I find use existing torch distributions, and I have not fully understood the custom distributions that are included in the code (some of them have many extra methods whose purpose I do not grasp).

It’s probably sufficient to use factor; see other posts and the repo for examples/details, e.g. this post.

Dear @martinjankowiak,

Apologies for the late reply and for reactivating this thread, but I would like to ask some follow-up questions.

So, according to you, my goal, i.e. running MCMC/HMC on a distribution that is not built from PyTorch distribution objects, is achievable?

In the post you linked to I read the following:

The latter form using pyro.factor() can be useful if you have a bunch of PyTorch code to compute a (possibly non-normalized) likelihood like fn(loc, data), but it is inconvenient to wrap that code in a Distribution interface.

This does sound like my situation: I ‘just’ want to introduce a custom log-likelihood function to sample any given model (I have several use cases, so it has to be general enough), not per se create my own custom distribution object.

The model example given in that post is a little too simple for my case, specifically the sample part:

def model_2(data):
   loc = pyro.sample("loc", dist.Normal(0, 1))
   pyro.factor("obs", dist.Normal(loc, 1).log_prob(data))

In my case, I have some arbitrary model with a set of variables, but I do not know their distribution. I interpret loc = pyro.sample("loc", dist.Normal(0, 1)) as meaning that they assume this variable is normally distributed, but I might be misinterpreting this. Could it be that loc = .. acts as the prior for that variable?

There are some examples in the documentation, such as Example: Epidemiological inference via HMC (Pyro Tutorials 1.8.6 documentation), in which the function vectorized_model looks like it might do what I want. Can the two sample statements for S_aux and I_aux be regarded as priors?

@DavidDouwe Hi David, please refer to the documentation: yes, sample statements encode prior information.

Viewed operationally, HMC is an algorithm that takes a (possibly unnormalized) log density and generates approximate samples from the corresponding density. For example, the log density -0.5 x ** 2 corresponds to a standard normal (Gaussian) distribution. Each sample statement essentially encodes part of the log density, and a factor statement is another way to add terms to the log density of a model.