Regarding nullifying priors and gradient issues for pyro.factor statements

Dear Pyro developers/users,

For a couple of weeks I’ve been trying to finish an implementation that I was
struggling with early last year, and that I have only now had the time to pick
back up.

In a previous post
(“Regarding using a custom distribution with HMC”), I
asked how to incorporate a custom likelihood calculation into the potential
energy of our stochastic model. The advice was to use pyro.factor, which I have
tried to do.

For testing purposes I would like to remove (nullify) any effect of the priors,
so that I can benchmark my implementation against another code/method. To do
this, I used (some form of) the approach from the GitHub issue “Get inference
for flat prior” (pyro-ppl/pyro, Issue #1513).

I have implemented a model (one of many attempts) as follows:

from functools import partial

import pyro
import pyro.distributions as dist
import torch


def return_LogProbOverloadNormalPriorMVNFactorModel(mvn_configuration):
    """
    Function to return a linefit factor model
    """

    partial_MVN_likelihood_function = partial(
        MVN_likelihood_function, mvn_configuration=mvn_configuration
    )

    # create distribution with overloaded log_prob function
    newNormal = overloaded_dist(dist.Normal)

    def LogProbOverloadNormalPriorMVNFactorModel():
        """
        MVNFactorModel
        """

        x = pyro.sample("x", newNormal(0, 1))
        y = pyro.sample("y", newNormal(0, 1))

        # Add likelihood to function
        model_factor = pyro.factor(
            "model", partial_MVN_likelihood_function(torch.Tensor([x, y]))
        )

    return LogProbOverloadNormalPriorMVNFactorModel

where newNormal is a subclass of dist.Normal with an overloaded log_prob, created by:

def overloaded_dist(dist_class):
    class newDist(dist_class):
        def log_prob(self, value):
            return value.new_tensor([0.0])

    return newDist

and mvn_configuration is an arbitrary configuration for a 2-D Gaussian.

I chose the Normal distribution for the priors because its support is the whole
real line. HMC therefore needs no constraint transform for it, so no log-Jacobian
term shows up in the potential function, and simply overloading its log_prob
should be enough.
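The claim about transforms can be checked directly in PyTorch. The minimal sketch below uses plain torch.distributions, not Pyro internals. As far as I understand, looking up a bijection from unconstrained space onto the support is also how HMC/NUTS set up their unconstrained parameters. The sketch looks up that bijection and evaluates its log-Jacobian term. For the real-valued support of a Normal the term is zero. A positive support, included here only for contrast, is reached through exp and contributes a nonzero term.

```python
import torch
from torch.distributions import Normal, biject_to, constraints

# Unconstrained value at which to evaluate the transforms
u = torch.tensor(0.5)

# Normal has real support, so the bijection is the identity (no Jacobian term)
t_real = biject_to(Normal(0.0, 1.0).support)
ladj_real = t_real.log_abs_det_jacobian(u, t_real(u))

# A positive support is reached via exp(u), which adds log|det J| = u
t_pos = biject_to(constraints.positive)
ladj_pos = t_pos.log_abs_det_jacobian(u, t_pos(u))

print(float(ladj_real), float(ladj_pos))  # 0.0 0.5
```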

With this model, the potential function that Pyro builds returns the value I
expect (i.e. just the negated log_prob of the “model” site). The problem arises
when pyro.infer.mcmc.util.initialize_model tries to find a valid set of initial
parameters. Further down the line, computing the gradient of the potential energy
with respect to z_nodes fails with the following error:

Traceback (most recent call last):
  File "/home/david/projects/hmc_project/repo/hmc_project_code/functions/analysis/model_analysis/main.py", line 75, in <module>
    analyse_model(model=LogProbOverloadNormalPriorMVNFactorModel, config=config, verbose=1, **analysis_settings_dict)
  File "/home/david/projects/hmc_project/repo/hmc_project_code/functions/analysis/model_analysis/analyse_model.py", line 74, in analyse_model
    inspect_via_initialise_model(
  File "/home/david/projects/hmc_project/repo/hmc_project_code/functions/analysis/model_analysis/inspection_functions/inspect_via_initialise_model.py", line 332, in inspect_via_initialise_model
    initial_params = _find_valid_initial_params(
  File "/home/david/.pyenv/versions/hmc3.9.9/lib/python3.9/site-packages/pyro/infer/mcmc/util.py", line 351, in _find_valid_initial_params
    pe_grad, pe = potential_grad(potential_fn, params)
  File "/home/david/.pyenv/versions/hmc3.9.9/lib/python3.9/site-packages/pyro/ops/integrator.py", line 85, in potential_grad
    grads = grad(potential_energy, z_nodes)
  File "/home/david/.pyenv/versions/hmc3.9.9/lib/python3.9/site-packages/torch/autograd/__init__.py", line 234, in grad
    return Variable._execution_engine.run_backward(
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
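The same error can be reproduced with plain PyTorch in the minimal sketch below. It assumes a quadratic stand-in for the factor and constant prior terms, both chosen to mimic the model above. The factor is computed from torch.Tensor([x, y]) and both prior terms are constants. Nothing in the potential energy is therefore connected to x or y, and torch.autograd.grad raises a RuntimeError.

```python
import torch

x = torch.tensor(1.0, requires_grad=True)
y = torch.tensor(1.0, requires_grad=True)

# torch.Tensor([...]) copies the *values* of x and y into a fresh tensor
# with no grad_fn, so the link back to x and y is lost
xy = torch.Tensor([x, y])

# Stand-in for the factor: standard 2-D Gaussian log density up to a constant
log_factor = -0.5 * (xy ** 2).sum()
# Stand-in for the two nullified priors: a constant with no autograd graph
prior_log_prob = torch.tensor(0.0)

potential_energy = -(log_factor + prior_log_prob)

try:
    torch.autograd.grad(potential_energy, [x, y])
    raised = False
except RuntimeError as err:
    raised = True
    print(err)
```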

This only happens when I nullify both priors, but even with regular (non-nullified)
priors I see some odd behaviour:

  • the NUTS sampling is very inefficient
  • the guide training does not converge well.

Admittedly I have not tested this thoroughly yet (I will do so in the coming
days), but both observations would make sense if the pyro.factor statement had no
effect on the gradient of the model.

My next step is to explicitly check the effect of the pyro.factor statement on
the gradient, but I did not want to be held up any longer, so I’m posting this now.

My questions now are:

  • Is there any way for Pyro/torch to know the grad_fn information of my factor
    statement? (It is essentially an external calculation, in which the tensors
    might not be propagated properly.)
  • How should I approach solving this?
    • Should I provide a custom grad_fn?
  • If my approach to removing the prior effects is what causes this issue,
    what is the proper way of doing so?

Perhaps a clearer way to demonstrate the issue is to compute the gradients of two different models explicitly.

The first model is a simple prior-only model:

import pyro
import pyro.distributions as dist


def NormalPrior():
    """
    Prior-only model: independent standard normal priors on x and y
    """

    x = pyro.sample("x", dist.Normal(0, 1))
    y = pyro.sample("y", dist.Normal(0, 1))

for which the gradient of the negative log posterior is:

\begin{equation} \nabla\bigl(-\log p(x,y)\bigr) = (x,\ y) \end{equation}

The second model is the same, but with an added custom likelihood:

def NormalPriorMVNFactorModel():
    """
    MVNFactorModel
    """

    x = pyro.sample("x", dist.Normal(0, 1))
    y = pyro.sample("y", dist.Normal(0, 1))

    # Add likelihood to function
    model_factor = pyro.factor(
        "model", MVN_likelihood_function(torch.Tensor([x, y]))
    )

Here MVN_likelihood_function returns the log_prob of a standard 2-D Gaussian, i.e. \mathcal{N}(\mathbf{0}, I_2).

The analytic gradient of this negative log posterior is:

\begin{equation} \nabla\bigl(-\log p(x,y)\bigr) = (2x,\ 2y) \end{equation}
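Writing out the potential energies U = -\log p of both models makes the expected values explicit. Here \mathcal{N}(x;0,1) is the standard normal density and \mathcal{N}\left((x,y);\mathbf{0},I_2\right) the standard 2-D Gaussian density. The values at (x, y) = (1, 1) are the ones the Pyro output should reproduce.

\begin{equation} U_{\text{prior}}(x,y) = -\log\mathcal{N}(x;0,1) - \log\mathcal{N}(y;0,1) = \tfrac{1}{2}\left(x^2+y^2\right) + \log(2\pi) \end{equation}

\begin{equation} U_{\text{factor}}(x,y) = U_{\text{prior}}(x,y) - \log\mathcal{N}\left((x,y);\mathbf{0},I_2\right) = x^2+y^2 + 2\log(2\pi) \end{equation}

\begin{equation} U_{\text{prior}}(1,1) = 1 + \log(2\pi) \approx 2.8379, \qquad U_{\text{factor}}(1,1) = 2 + 2\log(2\pi) \approx 5.6758 \end{equation}

Differentiating gives (x, y) and (2x, 2y) respectively, matching the two gradients above.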

Computing the gradient in Pyro and evaluating it at, e.g., (x=1, y=1) gives the following.

For the NormalPrior model we get

########################################
Model WITHOUT the factor statement, evaluated at {'x': tensor([1.]), 'y': tensor([1.])}
potential value: 2.837877035140991
gradient: {'x': tensor([1.]), 'y': tensor([1.])}

which is what we expect.

For the NormalPriorMVNFactorModel we get

########################################
Model WITH the factor statement, evaluated at {'x': tensor([1.]), 'y': tensor([1.])}
potential value: 5.6757541015503366
gradient: {'x': tensor([1.]), 'y': tensor([1.])}

which is not what we expect (which would be {'x': tensor([2.]), 'y': tensor([2.])}).

This suggests to me that the gradient of the factor statement is simply not taken into account.

you’re misusing torch autograd. torch doesn’t follow the gradient tape of the elements in this kind of tensor-creation op: torch.Tensor([x, y])

as long as you use torch autograd correctly, factor statements will have regular gradients, as expected.
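The difference can be shown in a minimal plain-PyTorch sketch. It rebuilds the potential energy of NormalPriorMVNFactorModel by hand, assuming, as stated in the question, that MVN_likelihood_function is the log density of a standard 2-D Gaussian. It then evaluates the gradient at (1, 1) twice. The first pass uses torch.Tensor([x, y]), which copies values only. The second uses torch.stack([x, y]), which keeps x and y on the autograd tape. The helper potential_and_grad is a hypothetical name for illustration.

```python
import torch
from torch.distributions import MultivariateNormal, Normal


def potential_and_grad(build_xy):
    """Potential energy U(x, y) of the prior + factor model and its gradient at (1, 1)."""
    x = torch.tensor(1.0, requires_grad=True)
    y = torch.tensor(1.0, requires_grad=True)
    prior = Normal(0.0, 1.0)
    # Stand-in for MVN_likelihood_function: standard 2-D Gaussian log density
    likelihood = MultivariateNormal(torch.zeros(2), torch.eye(2))
    log_joint = prior.log_prob(x) + prior.log_prob(y) + likelihood.log_prob(build_xy(x, y))
    potential = -log_joint
    grad_x, grad_y = torch.autograd.grad(potential, [x, y])
    return potential.item(), grad_x.item(), grad_y.item()


# torch.Tensor copies values only: the factor adds nothing to the gradient
print(potential_and_grad(lambda x, y: torch.Tensor([x, y])))
# torch.stack keeps the autograd graph: the gradient matches (2x, 2y)
print(potential_and_grad(lambda x, y: torch.stack([x, y])))
```

Both calls return the same potential value (about 5.6758). Only the second returns the analytic gradient (2, 2). The first returns (1, 1), which is exactly the behaviour reported in the question.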

a better way to “nullify” a prior might be to wrap it in a poutine.scale context manager with a very small scale like 10^{-10}; see the Pyro documentation on Poutine (Effect handlers).


Dear @martinjankowiak,

Thank you for your suggestions. Both of them seem to do the trick! I’ve since read up more on automatic differentiation (autograd) in Torch and JAX, and it makes more sense now.

Coming back to your comment on tensor creation: is there a proper way to collect input tensors into a single array/tensor so that I can do matrix operations on it, while keeping each element’s gradient tape intact (i.e. so that autograd registers the individual operations that the matrix operation performs on each element)?

This might be more Torch-specific, though, and better asked on the Torch forum.

David

yes, this is a better question for elsewhere, but in short you have to use e.g. indexing, masking, concatenation (torch.cat), etc.:

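import torch

x = torch.ones(1, requires_grad=True)
y = torch.ones(1, requires_grad=True)

# concatenation keeps both x and y on the autograd tape
xy1 = torch.cat([x, y])

# so does assigning into a preallocated tensor via indexing
xy2 = torch.zeros(2)
xy2[0] = x
xy2[1] = y
# etc.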
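On the matrix-operation question: once the elements are combined with torch.cat (or torch.stack for 0-d tensors), subsequent matrix operations stay on the autograd tape. The minimal sketch below uses a quadratic form xy^T A xy, whose gradient should be (A + A^T) xy. The matrix A and the values of x and y are arbitrary assumptions chosen for illustration.

```python
import torch

x = torch.ones(1, requires_grad=True)
y = torch.full((1,), 2.0, requires_grad=True)

# Concatenate into a single vector without breaking the gradient tape
xy = torch.cat([x, y])

# Arbitrary (non-symmetric) matrix, used only for illustration
A = torch.tensor([[2.0, 1.0],
                  [0.0, 3.0]])

# Quadratic form xy^T A xy: a matrix operation on the combined tensor
q = xy @ A @ xy
grad_x, grad_y = torch.autograd.grad(q, [x, y])

# Analytic gradient of the quadratic form: (A + A^T) xy
expected = (A + A.T) @ xy.detach()
print(grad_x, grad_y, expected)  # tensor([6.]) tensor([13.]) tensor([ 6., 13.])
```

Autograd tracks each element through the matrix products, so the per-element gradients agree with the analytic result.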

