Dear PYRO developers/users
For a couple of weeks now I've been trying to finish an implementation that I
was struggling with earlier last year, but only now have had the time to pick
back up again.
In a previous post (Regarding using a custom distribution with HMC) I asked how
to incorporate a custom likelihood calculation in the potential value for our
stochastic model. The advice was to use pyro.factor, which I tried to do.
For testing purposes I would like to remove / nullify any effect of the priors,
so I can benchmark my implementation against some other code / method. To do
this, I used the approach from Get inference for flat prior · Issue #1513 ·
pyro-ppl/pyro · GitHub (or at least some form of it).
I have implemented a model (one of many attempts) as follows:
def return_LogProbOverloadNormalPriorMVNFactorModel(mvn_configuration):
    """
    Function to return a linefit factor model
    """
    partial_MVN_likelihood_function = partial(
        MVN_likelihood_function, mvn_configuration=mvn_configuration
    )
    # create distribution with overloaded log_prob function
    newNormal = overloaded_dist(dist.Normal)

    def LogProbOverloadNormalPriorMVNFactorModel():
        """
        MVNFactorModel
        """
        x = pyro.sample("x", newNormal(0, 1))
        y = pyro.sample("y", newNormal(0, 1))
        # Add likelihood to function
        model_factor = pyro.factor(
            "model", partial_MVN_likelihood_function(torch.Tensor([x, y]))
        )

    return LogProbOverloadNormalPriorMVNFactorModel
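As a side note, I now suspect that `torch.Tensor([x, y])` in the model above already copies the sampled values into a fresh tensor and detaches them from the autograd graph, whereas `torch.stack([x, y])` would keep them connected. A small sketch of the difference (my own check, independent of the full model):

```python
import torch

x = torch.tensor(1.0, requires_grad=True)
y = torch.tensor(2.0, requires_grad=True)

copied = torch.Tensor([x, y])   # copies the scalar values, drops the graph
stacked = torch.stack([x, y])   # keeps both samples in the graph

print(copied.requires_grad)   # False
print(stacked.requires_grad)  # True
```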
Here newNormal is a Normal distribution class with an overloaded log_prob:
def overloaded_dist(dist_class):
    class newDist(dist_class):
        def log_prob(self, value):
            return value.new_tensor([0.0])

    return newDist
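Looking at this overload again, `value.new_tensor([0.0])` may itself be a problem: `new_tensor` always copies its data and returns a tensor detached from the graph, so this prior log_prob carries no grad_fn. Something like `value * 0.0` would also return zero while staying differentiable (a sketch I still need to test against the full model):

```python
import torch

value = torch.tensor(0.7, requires_grad=True)

detached = value.new_tensor([0.0])  # new_tensor always detaches from the graph
connected = value * 0.0             # zero log-prob, but still differentiable

print(detached.requires_grad)   # False
print(connected.requires_grad)  # True
```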
and the mvn_configuration is some arbitrary configuration for a 2-d Gaussian.
I chose the Normal distribution for the priors since it does not have any
inherent transformations that show up in the potential function calculation,
so just overloading its log_prob should do the job.
With this model, the potential function that Pyro builds up returns the
-log_prob that I expect (i.e. just the value of the log_prob of the "model"
site). The problem arises when pyro.infer.mcmc.util.initialize_model tries to
find a good set of initial parameters: down the line, in the calculation of
the gradient of the z_nodes, the following issue arises:
Traceback (most recent call last):
File "/home/david/projects/hmc_project/repo/hmc_project_code/functions/analysis/model_analysis/main.py", line 75, in <module>
analyse_model(model=LogProbOverloadNormalPriorMVNFactorModel, config=config, verbose=1, **analysis_settings_dict)
File "/home/david/projects/hmc_project/repo/hmc_project_code/functions/analysis/model_analysis/analyse_model.py", line 74, in analyse_model
inspect_via_initialise_model(
File "/home/david/projects/hmc_project/repo/hmc_project_code/functions/analysis/model_analysis/inspection_functions/inspect_via_initialise_model.py", line 332, in inspect_via_initialise_model
initial_params = _find_valid_initial_params(
File "/home/david/.pyenv/versions/hmc3.9.9/lib/python3.9/site-packages/pyro/infer/mcmc/util.py", line 351, in _find_valid_initial_params
pe_grad, pe = potential_grad(potential_fn, params)
File "/home/david/.pyenv/versions/hmc3.9.9/lib/python3.9/site-packages/pyro/ops/integrator.py", line 85, in potential_grad
grads = grad(potential_energy, z_nodes)
File "/home/david/.pyenv/versions/hmc3.9.9/lib/python3.9/site-packages/torch/autograd/__init__.py", line 234, in grad
return Variable._execution_engine.run_backward(
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
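For reference, the same RuntimeError can be reproduced outside Pyro: if the potential energy is a constant tensor with no path back to the parameters, torch.autograd.grad has nothing to differentiate:

```python
import torch
from torch.autograd import grad

z = torch.tensor(0.5, requires_grad=True)
potential_energy = torch.tensor(0.0)  # constant; no grad_fn, no link to z

try:
    grad(potential_energy, [z])
except RuntimeError as err:
    print(err)  # element 0 of tensors does not require grad and does not have a grad_fn
```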
This only happens when I nullify both my priors, but even with 'normal' priors
I find some odd behaviour:
- the NUTS sampling is very inefficient
- the guide training does not converge well
Admittedly I have not tested this to a high degree; I will do that in the
coming days. But it would make sense if the pyro.factor statement had no
effect on the gradient of the model.
My next step is to explicitly check the effect of the pyro.factor statement on
the gradient, but I did not want to be held up any longer, so I'm posting this now.
My questions now are:
- Is there any way that Pyro/torch would be able to know any grad function
information about my factor statement? (Since it is basically an external
calculation where the tensors might not be propagated properly.)
- How would I approach solving this situation?
- Should I provide a custom grad_fn function?
- If it is my approach of removing the prior effects that is causing this issue,
what is the proper way of doing so?
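Regarding the custom grad_fn question, one option I'm considering is wrapping the external calculation in a torch.autograd.Function with an explicit backward. This is only a sketch: external_loglike and external_grad are hypothetical stand-ins (here a toy -0.5 * ||p||^2) for the external likelihood and its analytic gradient:

```python
import numpy as np
import torch

# hypothetical external calculation and its gradient (toy: -0.5 * ||p||^2)
def external_loglike(p):
    return -0.5 * float(np.dot(p, p))

def external_grad(p):
    return -p

class ExternalLikelihood(torch.autograd.Function):
    @staticmethod
    def forward(ctx, params):
        ctx.save_for_backward(params)
        # run the external (non-torch) calculation on plain numpy values
        value = external_loglike(params.detach().numpy())
        return params.new_tensor(value)

    @staticmethod
    def backward(ctx, grad_output):
        (params,) = ctx.saved_tensors
        # supply the gradient ourselves so autograd has a grad_fn to follow
        g = external_grad(params.detach().numpy())
        return grad_output * params.new_tensor(g)

p = torch.tensor([1.0, 2.0], requires_grad=True)
ExternalLikelihood.apply(p).backward()
print(p.grad)  # tensor([-1., -2.])
```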