Dear developers and other Pyro experts,
I have been trying to set up an HMC/MCMC/NUTS sampling routine with Pyro, specifically by providing my own likelihood function that ‘describes’ the posterior distribution but is not constructed from a combination of torch distributions. So far I have not been able to figure out how to solve my issues, and there are few examples that come close to my use case (which may mean I am either overcomplicating things or missing the key to how to approach it entirely).
My questions, then, are:

1. How do I construct a custom distribution such that Pyro knows to use it as the distribution that needs to be sampled?
2. Which class methods are necessary to do this?
3. How do I set up the model/distribution instance properly?

For context, I am currently trying to sample a 5-dimensional model given observations, whose likelihood is described by a function `posterior_function` that compares the data with the results of the model. My ultimate goal is to use NeuTra NUTS sampling (and modifications thereof) on my model. This may introduce other issues later down the line, but currently I am stuck with plain NUTS sampling.
My code setup is as follows:
The distribution class:
```python
import numpy as np
import pyro
import pyro.distributions as dist
import torch
from torch.distributions import constraints


class BayesFitClass(pyro.distributions.Distribution):
    """
    BayesFit class
    """
    support = constraints.real

    def __init__(self, posterior_function, data):
        """
        Init function to set up the BayesFit model.
        Input:
            posterior_function: function that compares the data with the model
                                for a dict of parameters
            data: the observational data
        """
        self.posterior_function = posterior_function
        self.data = data
        super(pyro.distributions.Distribution, self).__init__()

    def sample(self, sample_shape=torch.Size()):
        """
        Sample method for the BayesFit posterior
        """
        # Parameters
        slopein = pyro.sample('slopein', dist.Uniform(0 + 1e-4, 3))
        log10_slopeout = pyro.sample('log10_slopeout', dist.Uniform(np.log10(3), np.log10(4e2)))
        log10_J0 = pyro.sample('log10_J0', dist.Uniform(0, 4))
        log10_rho0 = pyro.sample('log10_rho0', dist.Uniform(3, 9))
        log10_Rs = pyro.sample('log10_Rs', dist.Uniform(-4, 2))
        return torch.Tensor([slopein, log10_slopeout, log10_J0, log10_rho0, log10_Rs])

    def log_prob(self, sample_tensor):
        """
        Function to get the log probability from the posterior
        """
        # Put in param dict shape
        param_dict = {
            'slopein': sample_tensor[0],
            'log10_slopeout': sample_tensor[1],
            'log10_J0': sample_tensor[2],
            'log10_rho0': sample_tensor[3],
            'log10_Rs': sample_tensor[4],
        }
        return self.posterior_function(
            param_dict=param_dict,
            data=self.data
        )

    def batch_shape(self):
        """
        TODO: add batch shape property
        """

    def event_shape(self):
        """
        TODO: add event shape property
        """
        return 5
```
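Two patterns in the class above are worth checking, both here and anywhere inside `posterior_function`. First, `torch.Tensor([...])` copies the values into a new tensor with no autograd history (as do `float(...)`, `.item()` or a round trip through NumPy), so gradients cannot flow back to the parameters, whereas `torch.stack` keeps them connected. Second, `sample_tensor[0]` picks the first row, which is the first parameter only for a single unbatched draw; `sample_tensor[..., 0]` also works when a batch of draws with shape `(..., 5)` is passed in. A plain-PyTorch illustration with arbitrary example tensors:

```python
import torch

# Two parameters that require gradients, standing in for sampled values
a = torch.tensor(1.5, requires_grad=True)
b = torch.tensor(2.0, requires_grad=True)

# torch.Tensor([...]) copies the numbers into a brand-new tensor:
# the autograd history linking it to a and b is lost
copied = torch.Tensor([a, b])

# torch.stack builds the vector from the same tensors and keeps the graph
stacked = torch.stack([a, b])
stacked.sum().backward()  # gradients now reach a and b

# A batch of 4 draws of a 5-parameter vector
batch = torch.randn(4, 5)
row = batch[0]          # shape (5,): the first draw, not the first parameter
param0 = batch[..., 0]  # shape (4,): the first parameter of every draw
```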
Some distribution and model instance creation functions:
```python
def return_BayesFitClass_instance(**bayesfit_class_init_args):
    """
    Function to return BayesFitClass instance
    """
    BayesFitClass_instance = BayesFitClass(**bayesfit_class_init_args)
    return BayesFitClass_instance


def new_BayesFitClass_model(**bayesfit_class_init_args):
    """
    Wrapper function to return a model function
    """
    def model():
        return pyro.sample('model', return_BayesFitClass_instance(**bayesfit_class_init_args))
    return model


def return_BayesFitClass_model(**bayesfit_class_init_args):
    """
    Function that returns a model method for the sampling of the BayesFitClass
    """
    return new_BayesFitClass_model(**bayesfit_class_init_args)
```
Running NUTS sampling with this model currently leads to the following error:
```
[generate_samples.py:143  generate_samples ] 20220726 23:18:57,580: Running NUTS sampler directly on target distribution
Warmup: 0%  0/150 [00:00, ?it/s]Called bayesfit_posterior_function:
with parameters: {'slopein': tensor(1.4403, grad_fn=<SelectBackward0>), 'log10_slopeout': tensor(1.1604, grad_fn=<SelectBackward0>), 'log10_J0': tensor(0.9963, grad_fn=<SelectBackward0>), 'log10_rho0': tensor(0.4383, grad_fn=<SelectBackward0>), 'log10_Rs': tensor(0.3154, grad_fn=<SelectBackward0>)}
norm: 0.00011603959196843638
evaluated_df: [8.60935879e22 2.51503576e32 8.86303219e32 ... 2.89534001e12
 3.89460097e33 5.07838347e36]
any((evaluated_df/norm)==0): False
Traceback (most recent call last):
  File "/home/david/projects/hmc_project/repo/hmc_project_code/projects/usecase_projects/project_paula_gherghinescu/main.py", line 66, in <module>
    generate_samples(
  File "/home/david/projects/hmc_project/repo/hmc_project_code/functions/sampling/generate_samples.py", line 147, in generate_samples
    nuts_on_target_results = sampler_wrapper(
  File "/home/david/projects/hmc_project/repo/hmc_project_code/functions/sampling/generate_samples.py", line 39, in sampler_wrapper
    sampler_results = sampler_function(
  File "/home/david/projects/hmc_project/repo/hmc_project_code/functions/sampling/nuts_aahmc_sampler.py", line 53, in nuts_aahmc_sampler
    mcmc.run()
  File "/home/david/.pyenv/versions/hmc3.9.9/lib/python3.9/site-packages/pyro/poutine/messenger.py", line 12, in _context_wrap
    return fn(*args, **kwargs)
  File "/home/david/.pyenv/versions/hmc3.9.9/lib/python3.9/site-packages/pyro/infer/mcmc/api.py", line 563, in run
    for x, chain_id in self.sampler.run(*args, **kwargs):
  File "/home/david/.pyenv/versions/hmc3.9.9/lib/python3.9/site-packages/pyro/infer/mcmc/api.py", line 223, in run
    for sample in _gen_samples(
  File "/home/david/.pyenv/versions/hmc3.9.9/lib/python3.9/site-packages/pyro/infer/mcmc/api.py", line 144, in _gen_samples
    kernel.setup(warmup_steps, *args, **kwargs)
  File "/home/david/.pyenv/versions/hmc3.9.9/lib/python3.9/site-packages/pyro/infer/mcmc/hmc.py", line 325, in setup
    self._initialize_model_properties(args, kwargs)
  File "/home/david/.pyenv/versions/hmc3.9.9/lib/python3.9/site-packages/pyro/infer/mcmc/hmc.py", line 259, in _initialize_model_properties
    init_params, potential_fn, transforms, trace = initialize_model(
  File "/home/david/.pyenv/versions/hmc3.9.9/lib/python3.9/site-packages/pyro/infer/mcmc/util.py", line 468, in initialize_model
    initial_params = _find_valid_initial_params(
  File "/home/david/.pyenv/versions/hmc3.9.9/lib/python3.9/site-packages/pyro/infer/mcmc/util.py", line 351, in _find_valid_initial_params
    pe_grad, pe = potential_grad(potential_fn, params)
  File "/home/david/.pyenv/versions/hmc3.9.9/lib/python3.9/site-packages/pyro/ops/integrator.py", line 76, in potential_grad
    potential_energy = potential_fn(z)
  File "/home/david/.pyenv/versions/hmc3.9.9/lib/python3.9/site-packages/pyro/infer/mcmc/util.py", line 281, in _potential_fn
    log_joint = self.trace_prob_evaluator.log_prob(model_trace)
  File "/home/david/.pyenv/versions/hmc3.9.9/lib/python3.9/site-packages/pyro/infer/mcmc/util.py", line 238, in log_prob
    return model_trace.log_prob_sum()
  File "/home/david/.pyenv/versions/hmc3.9.9/lib/python3.9/site-packages/pyro/poutine/trace_struct.py", line 207, in log_prob_sum
    log_p = scale_and_mask(log_p, site["scale"], site["mask"]).sum()
AttributeError: 'float' object has no attribute 'sum'
```
I think this is caused by some of the class methods not returning the correct type of output, which hints at my lack of understanding of what they should return.
Current versions of the packages are:

```
pyroapi==0.1.2
pyro-ppl==1.8.0
torch==1.10.2
```
I would be very grateful if someone with more experience could point me towards a solution, or to a write-up of how to approach this type of sampling. Most of the tutorials I find use existing torch distributions, and I have not fully understood the custom distributions included in the Pyro source (some of them have many extra methods whose purpose I do not grasp).