Natural gradient ascent


#1

Hi everyone,

I would like to implement natural gradient ascent on the ELBO, but I am having a hard time figuring out whether I should use the existing classes in pyro.optim and pyro.infer or implement a completely new inference class.

The algorithm should be straightforward to implement:

  1. Iterate over the sample sites of the guide and, for each site, compute the inverse Fisher information matrix G(param), which depends on the current parameter values.

  2. Compute the ELBO and the corresponding stochastic gradient for each parameter, SG(param).

  3. Update the parameters using natural gradient ascent: param -> param + lr*G(param)*SG(param)
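The three steps above can be sketched without Pyro on a toy problem. This is only an illustration under stated assumptions: the guide is a single Normal N(mu, sigma^2), the target is a Normal N(m, s^2), and the exact ELBO gradient is used in place of a stochastic estimate so that the run is deterministic. All function names are made up for the example.

```python
# Natural gradient ascent for a Normal guide q = N(mu, sigma^2) fitted to a
# Normal target p = N(m, s^2). Here ELBO = -KL(q || p), so its gradient is
# available in closed form (an assumption made to keep the sketch deterministic).

def elbo_grad(mu, sigma, m, s):
    """Closed-form gradient of the ELBO with respect to (mu, sigma)."""
    d_mu = -(mu - m) / s**2
    d_sigma = 1.0 / sigma - sigma / s**2
    return d_mu, d_sigma


def inverse_fisher_diag(sigma):
    """Diagonal of the inverse Fisher information of N(mu, sigma^2).

    In the (mu, sigma) parametrisation the Fisher information is
    diag(1 / sigma^2, 2 / sigma^2), so its inverse is diag(sigma^2, sigma^2 / 2).
    """
    return sigma**2, sigma**2 / 2.0


def natural_gradient_ascent(mu, sigma, m, s, lr=0.1, steps=500):
    for _ in range(steps):
        g_mu, g_sigma = inverse_fisher_diag(sigma)  # step 1: G(param)
        d_mu, d_sigma = elbo_grad(mu, sigma, m, s)  # step 2: SG(param), exact here
        mu += lr * g_mu * d_mu                      # step 3: param + lr * G * SG
        sigma += lr * g_sigma * d_sigma
    return mu, sigma


mu, sigma = natural_gradient_ascent(mu=0.0, sigma=1.0, m=3.0, s=2.0)
```

With these settings the guide parameters approach the target values m and s. In a real Pyro setting the gradient in step 2 would be a noisy estimate, so the learning rate would likely need more care.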

My current idea is to use the SGD optimiser (without momentum) with Trace_ELBO, read the gradients at each update step, and update the parameter values with natural gradients instead of the canonical gradients that the SGD optimiser would apply.

I would appreciate an expert opinion on this. Do you think this would be a reasonably good implementation, or do you see a more efficient way to implement the algorithm?


#2

Hi @pyroman, I think a clean way to implement this would be at the distribution layer, completely independent of the SVI and ELBO machinery. For example, you could create special distributions that are parametrized by natural parameters, say

class NaturalNormal(Normal):
    def __init__(self, nat_param):
        loc, scale = some_differentiable_transform(nat_param)
        super(NaturalNormal, self).__init__(loc, scale)

If nat_param is unconstrained, then natural gradient ascent should “just work”:

- loc = pyro.param("loc", ...)
- scale = pyro.param("scale", ...)
- x = pyro.sample("x", Normal(loc, scale))
+ nat_param = pyro.param("nat_param", ...)
+ x = pyro.sample("x", NaturalNormal(nat_param))

If there are constraints, we might need to add some isometric transforms, e.g. AbsTransform.
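As an illustration of what some_differentiable_transform might look like for a Normal, here is a plain-Python sketch of the standard map between the natural parameters (eta1 = loc / scale^2, eta2 = -1 / (2 * scale^2)) and (loc, scale). The function names are hypothetical, and a Pyro version would need to use torch operations so that gradients can flow through the transform. Note that eta2 must be negative, which is one example of the kind of constraint mentioned above.

```python
import math


def natural_to_loc_scale(eta1, eta2):
    """Map Normal natural parameters to (loc, scale). Requires eta2 < 0."""
    if eta2 >= 0:
        raise ValueError("eta2 must be negative for a valid Normal distribution")
    variance = -1.0 / (2.0 * eta2)
    return eta1 * variance, math.sqrt(variance)


def loc_scale_to_natural(loc, scale):
    """Inverse map: (loc, scale) to Normal natural parameters."""
    variance = scale**2
    return loc / variance, -1.0 / (2.0 * variance)


# Round trip through both parametrisations.
eta1, eta2 = loc_scale_to_natural(1.5, 2.0)
loc, scale = natural_to_loc_scale(eta1, eta2)
```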

If this low-level version works, you could then make it easier to use via a program transform, e.g. making a natural version of pyro.lift.

Also note that there is already an ExponentialFamily distribution in torch.distributions, so you may be able to automate some of the parameterizations.

This is just an idea :smile: there are many ways to implement this. Let me know what you think.


#3

Hi @fritzo, thanks for the suggestion. That is an interesting and very clean solution. I will have a look at the classes and functions you suggested.

The only hurdle I see here is that there might not be a “natural” parametrisation of an arbitrary distribution, even if it belongs to the exponential family. My guess is that such a parametrisation only exists for the natural exponential family (e.g. a normal distribution with known variance, or a gamma distribution with known shape).

As long as the Fisher information is different from the identity matrix, the Riemannian metric of the statistical manifold will be non-Euclidean. This implies that one has to multiply the ordinary gradient vector, estimated over the parameters of that distribution, by the inverse of the corresponding Riemannian metric, i.e. by the inverse Fisher information.

For example, if my requirement were to pass to an optimiser not only the gradients but also the manifold curvature, what would be an elegant way to solve this? My current solution is to use parameter names that include the distribution type, and to compute the corresponding Riemannian metric, that is, the gradient multipliers, within the optimiser (this works only if the information matrix can be made diagonal). However, I am aware that this is a dirty solution, as it does not fit well with the general Pyro structure.


#4

Ah good point, I hadn’t thought about curvature!

I think a distribution-layer solution is still workable. To pass the manifold curvature tensor to the optimizer, I would set a custom attribute on the optimized torch.Tensor object, such as ._pyro_metric. (I usually prefix custom attributes with an underscore to avoid name conflicts with builtins, and I usually prefix with _pyro_ so that it is clear in a debugger which attributes are being set by which library; we use this technique for things like ._pyro_backward, ._pyro_dims, and the .unconstrained weakref to handle constraints.) Then you can implement a RiemannianAdam or something that checks hasattr(param, '_pyro_metric') and, if present, uses it as a preconditioner for the gradient (or something). WDYT? :smile:
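The attribute-based idea can be sketched without torch. In the snippet below, Param is a hypothetical stand-in for a tensor with a .grad field, ascent_step is a hypothetical plain gradient-ascent update (not RiemannianAdam), and _pyro_metric is assumed to hold a single diagonal entry of the metric, so preconditioning means dividing the gradient by it. None of this is existing Pyro API.

```python
class Param:
    """Hypothetical stand-in for a tensor: holds a value and its gradient."""

    def __init__(self, value):
        self.value = value
        self.grad = 0.0


def ascent_step(params, lr):
    """One gradient-ascent step that uses _pyro_metric when it is present."""
    for p in params:
        grad = p.grad
        if hasattr(p, "_pyro_metric"):
            # Assumed convention: the attribute stores a diagonal metric entry,
            # so the natural gradient is the ordinary gradient divided by it.
            grad = grad / p._pyro_metric
        p.value += lr * grad


with_metric = Param(0.0)
with_metric._pyro_metric = 4.0   # attached by the distribution layer, say
without_metric = Param(0.0)      # falls back to the ordinary gradient

with_metric.grad = without_metric.grad = 2.0
ascent_step([with_metric, without_metric], lr=0.1)
```

Parameters without the attribute are updated as usual, so such an optimizer could be mixed freely with ordinary parameters. A full metric matrix would need a linear solve instead of a division.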


#5

Great, thank you. This is very helpful.