Hi everyone,
I would like to implement natural gradient ascent on the ELBO, but I am having a hard time deciding whether I should use the existing classes in pyro.optim and pyro.infer or implement a completely new inference class.
The algorithm should be straightforward to implement:

1. Iterate over the sample sites of the guide and, for each site, compute the inverse Fisher information matrix G(param), which depends on the current parameter values.

2. Compute the ELBO and the corresponding stochastic gradient for each parameter, SG(param).

3. Update the parameters using natural gradient ascent: `param <- param + lr * G(param) * SG(param)`
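A minimal sketch of the three steps above for a single Normal guide site, in plain PyTorch rather than Pyro's SVI. It assumes a (loc, log_scale) parametrisation, in which the inverse Fisher information of a Normal is diag(scale², 1/2). To keep it deterministic, it maximises the closed-form -KL(q || p) to a fixed Normal target (this equals the ELBO of a model whose joint density is that Normal) instead of a Monte Carlo ELBO estimate. All function names are hypothetical.

```python
import torch


def neg_kl_normal(loc, log_scale, target_loc, target_scale):
    # Closed-form -KL(q || p) for q = Normal(loc, exp(log_scale)), p = Normal(target_loc, target_scale).
    scale = log_scale.exp()
    kl = (torch.log(target_scale / scale)
          + (scale ** 2 + (loc - target_loc) ** 2) / (2 * target_scale ** 2)
          - 0.5)
    return -kl


def inverse_fisher_normal(log_scale):
    # Diagonal inverse Fisher of a Normal in (loc, log_scale) coordinates: diag(scale**2, 1/2).
    return log_scale.exp() ** 2, 0.5


def natural_gradient_ascent(target_loc=3.0, target_scale=2.0, lr=0.5, num_steps=100):
    dtype = torch.float64
    loc = torch.tensor(0.0, dtype=dtype, requires_grad=True)
    log_scale = torch.tensor(0.0, dtype=dtype, requires_grad=True)
    target_loc = torch.tensor(target_loc, dtype=dtype)
    target_scale = torch.tensor(target_scale, dtype=dtype)
    for _ in range(num_steps):
        objective = neg_kl_normal(loc, log_scale, target_loc, target_scale)
        # Ordinary gradients, SG(param).
        grad_loc, grad_log_scale = torch.autograd.grad(objective, [loc, log_scale])
        with torch.no_grad():
            # G(param), evaluated at the current parameter values.
            g_loc, g_log_scale = inverse_fisher_normal(log_scale)
            # param <- param + lr * G(param) * SG(param)
            loc += lr * g_loc * grad_loc
            log_scale += lr * g_log_scale * grad_log_scale
    return loc.item(), log_scale.exp().item()


loc, scale = natural_gradient_ascent()
```

With these settings the iterates converge to the target's loc = 3 and scale = 2.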
My current idea is to use the SGD optimiser (without momentum) with Trace_ELBO, read the gradients at each step, and update the parameter values with the natural gradients instead of the ordinary gradients that SGD would use.
I would appreciate an expert opinion on this. Do you think this would be a reasonably good implementation, or do you see a more efficient way to implement the algorithm?
Hi @pyroman, I think a clean way to implement this would be at the distribution layer, completely independent of the SVI and ELBO machinery. For example, you could create special distributions that are parametrized by natural parameters, say
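```python
from pyro.distributions import Normal

class NaturalNormal(Normal):
    def __init__(self, nat_param):
        # some_differentiable_transform is a placeholder for the map
        # from natural parameters to (loc, scale)
        loc, scale = some_differentiable_transform(nat_param)
        super(NaturalNormal, self).__init__(loc, scale)
```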
If `nat_param` is unconstrained, then natural gradient ascent should "just work":
```diff
- loc = pyro.param("loc", ...)
- scale = pyro.param("scale", ...)
- x = pyro.sample("x", Normal(loc, scale))
+ nat_param = pyro.param("nat_param", ...)
+ x = pyro.sample("x", NaturalNormal(nat_param))
```
If there are constraints, we might need to add some isometric transforms, e.g. `AbsTransform`.
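For a Normal, for example, one possible `some_differentiable_transform` (an assumption for illustration, not an existing Pyro helper) uses the standard natural parameters eta1 = loc / scale² and eta2 = -1 / (2 scale²). The constraint eta2 < 0 is handled here by folding an unconstrained value through an absolute value, in the spirit of `AbsTransform`; a raw value of exactly 0 is a singularity. The function names are hypothetical, and `NaturalNormal` subclasses `torch.distributions.Normal` so the snippet runs without Pyro.

```python
import torch
from torch.distributions import Normal


def normal_natural_to_loc_scale(nat_param):
    # nat_param[..., 0] is eta1 = loc / scale**2.
    # nat_param[..., 1] is an unconstrained value r, folded to eta2 = -|r| = -1 / (2 * scale**2).
    eta1 = nat_param[..., 0]
    eta2 = -nat_param[..., 1].abs()
    variance = -0.5 / eta2
    return eta1 * variance, variance.sqrt()


def normal_loc_scale_to_natural(loc, scale):
    # Inverse map, choosing r = +1 / (2 * scale**2) so that eta2 = -r.
    variance = scale ** 2
    return torch.stack([loc / variance, 0.5 / variance], dim=-1)


class NaturalNormal(Normal):
    def __init__(self, nat_param, validate_args=None):
        loc, scale = normal_natural_to_loc_scale(nat_param)
        super().__init__(loc, scale, validate_args=validate_args)


# Round trip: (loc, scale) -> natural parameters -> NaturalNormal.
nat = normal_loc_scale_to_natural(torch.tensor([1.5, -2.0]), torch.tensor([0.5, 3.0]))
d = NaturalNormal(nat)
```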
If this low-level version works, you could then make it easier to use via a program transform, e.g. a natural version of pyro.lift.
Also note that torch.distributions already has an ExponentialFamily base class, so you may be able to automate some of the parameterizations.
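As an illustration of what ExponentialFamily could help automate: for an exponential family, the Fisher information with respect to the natural parameters equals the Hessian of the log-normalizer A(eta). The sketch below relies on the private `_natural_params` and `_log_normalizer` members that `torch.distributions.Normal` implements for ExponentialFamily; being private and undocumented, they may change between PyTorch versions. The helper name is hypothetical and it assumes a scalar (unbatched) distribution.

```python
import torch
from torch.autograd.functional import hessian
from torch.distributions import Normal


def fisher_in_natural_params(dist):
    # Stack the natural parameters (for Normal: loc / scale**2, -0.5 / scale**2).
    eta = torch.stack([p.detach() for p in dist._natural_params])

    def log_normalizer(eta_vec):
        return dist._log_normalizer(*eta_vec.unbind())

    # Fisher information in natural coordinates = Hessian of A(eta).
    return hessian(log_normalizer, eta)


q = Normal(torch.tensor(1.0, dtype=torch.float64), torch.tensor(2.0, dtype=torch.float64))
fisher = fisher_in_natural_params(q)
```

For Normal(1, 2) this gives [[4, 8], [8, 48]], which matches the covariance of the sufficient statistics (x, x²).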
This is just one idea; there are many ways to implement this. Let me know what you think.
Hi @fritzo, thanks for the suggestion. That is an interesting and very clean solution. I will have a look at the classes and functions you suggested.
The only hurdle I see is that there might not be a "natural" parametrisation of an arbitrary distribution, even if it belongs to the exponential family. My guess is that such a parametrisation only exists for natural exponential families (e.g. a normal distribution with known variance, or a gamma distribution with known shape).
As long as the inverse Fisher information is different from the identity matrix, the Riemannian metric of the statistical manifold will be non-Euclidean. This implies that one has to multiply the ordinary gradient vector, estimated with respect to the parameters of that distribution, by the inverse of the corresponding Riemannian metric (i.e. by the inverse Fisher information).
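Written out, with F(θ) the Fisher information (the Riemannian metric) and G(param) = F(θ)⁻¹ as in the first post, the natural-gradient update and the Normal(μ, σ) case in its (loc, scale) parametrisation are:

$$
\tilde{\nabla}_\theta \mathcal{L} = F(\theta)^{-1}\,\nabla_\theta \mathcal{L},
\qquad
\theta \leftarrow \theta + \mathrm{lr}\cdot F(\theta)^{-1}\,\nabla_\theta \mathcal{L}
$$

$$
F(\mu,\sigma) = \begin{pmatrix} 1/\sigma^2 & 0 \\ 0 & 2/\sigma^2 \end{pmatrix},
\qquad
F(\mu,\sigma)^{-1} = \begin{pmatrix} \sigma^2 & 0 \\ 0 & \sigma^2/2 \end{pmatrix}
$$

Even in this simple case the metric is diagonal but not the identity, so the ordinary gradient has to be rescaled per parameter.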
For example, if I need to pass the optimiser not only the gradients but also the manifold curvature, what would be an elegant way to do this? My current solution is to use parameter names that include the distribution type and to compute the corresponding Riemannian metric, that is, the gradient multipliers, within the optimiser (this works only if the information matrix can be made diagonal). However, I am aware that this is a dirty solution, as it does not fit well with the general Pyro structure.
Ah, good point, I hadn't thought about curvature!
I think a distribution-layer solution is still workable. To pass the manifold curvature tensor to the optimizer, I would set a custom attribute on the optimized `torch.Tensor` object, such as `._pyro_metric`. (I usually prefix custom attributes with an underscore to avoid name conflicts with builtins, and with `_pyro_` so that it is clear in a debugger which library set which attribute; we use this technique for things like `._pyro_backward`, `._pyro_dims`, and the `.unconstrained` weakref used to handle constraints.) Then you could implement a `RiemannianAdam` or something similar that checks `hasattr(param, '_pyro_metric')` and, if the attribute is present, uses it as a preconditioner for the gradient (or something along those lines). WDYT?
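A minimal sketch of that idea, assuming `_pyro_metric` holds a diagonal metric (e.g. a diagonal Fisher information) with the same shape as the parameter. `RiemannianSGD` is a hypothetical name; it uses plain SGD rather than Adam so the preconditioning is easy to see, and it works on a plain torch tensor rather than a `pyro.param`.

```python
import torch
from torch.optim import Optimizer


class RiemannianSGD(Optimizer):
    def __init__(self, params, lr=0.01):
        super().__init__(params, dict(lr=lr))

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad
                metric = getattr(p, "_pyro_metric", None)
                if metric is not None:
                    # Precondition by the inverse of the diagonal metric.
                    grad = grad / metric
                p.add_(grad, alpha=-group["lr"])
        return loss


# Usage: a badly scaled quadratic whose curvature (100) is attached as the metric.
x = torch.tensor([5.0], requires_grad=True)
x._pyro_metric = torch.tensor([100.0])
opt = RiemannianSGD([x], lr=1.0)
loss = 0.5 * 100.0 * (x ** 2).sum()
loss.backward()
opt.step()
```

With the metric attached, a single step with lr=1.0 lands on the minimum at 0, whereas plain SGD with the same learning rate would overshoot badly. Parameters without the attribute fall back to ordinary SGD.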
Great, thank you. This is very helpful.
@pyroman, this is a fantastic idea and I would love to see it implemented. Have you made any progress?
@nmancuso I was interested in testing natural gradients on a simple generative model I had, but it did not work well, so I gave up on implementing the elegant solution fritzo suggested. You can hack natural gradients into SVI with stochastic gradient descent (SGD) optimization. I only did a quick hack in which I read the parameter names, values, and gradients during optimisation, compute the corresponding inverse information metric from the values, and modify the gradients accordingly. For small problems this is an easy way to test whether natural gradients help in your use case.
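A minimal sketch of that kind of hack, assuming each Normal guide site is parametrised by two parameters named `<site>_loc` and `<site>_log_scale` (a hypothetical naming convention, not something Pyro enforces) and that the gradients are rescaled between `loss.backward()` and the SGD step. In Pyro the gradients would come from a Trace_ELBO loss; here a toy loss with unit gradients stands in for it, and the helper name is hypothetical.

```python
import torch


def apply_normal_natural_gradients(named_params):
    # Rescale .grad in place using the diagonal inverse Fisher of
    # Normal(loc, exp(log_scale)) in these coordinates: diag(scale**2, 1/2).
    with torch.no_grad():
        for name, loc in named_params.items():
            if not name.endswith("_loc"):
                continue
            log_scale = named_params[name[: -len("_loc")] + "_log_scale"]
            if loc.grad is not None:
                loc.grad.mul_(log_scale.exp() ** 2)
            if log_scale.grad is not None:
                log_scale.grad.mul_(0.5)


params = {
    "z_loc": torch.zeros(1, requires_grad=True),
    "z_log_scale": torch.tensor([2.0]).log().requires_grad_(),  # scale = 2
}
optimizer = torch.optim.SGD(list(params.values()), lr=0.1)

# Toy loss with unit gradients, standing in for the negative ELBO.
loss = params["z_loc"].sum() + params["z_log_scale"].sum()
loss.backward()
apply_normal_natural_gradients(params)  # natural gradients replace ordinary ones
optimizer.step()
```

Because the rescaling only touches `.grad`, the unmodified SGD optimizer then performs the natural-gradient step.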