I would like to implement natural gradient ascent over the ELBO, but I'm having a hard time figuring out whether I should use the existing classes in pyro.optim and pyro.infer or implement a completely new inference class.
The algorithm itself should be straightforward to implement:

1. Iterate over the sample sites of the guide and, for each site, compute the inverse Fisher information matrix G(param), which depends on the current parameter values.
2. Compute the ELBO and the corresponding stochastic gradient for each parameter, SG(param).
3. Update the parameters by natural gradient ascent: param -> param + lr * G(param) * SG(param).
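For concreteness, the update rule above can be sketched in a few lines of numpy; the Gaussian toy example and the helper name are my own choices for illustration, not part of any Pyro API:

```python
import numpy as np

def natural_gradient_step(param, elbo_grad, fisher_inv, lr):
    """One natural gradient ascent step:
    param -> param + lr * G(param) @ SG(param)."""
    return param + lr * fisher_inv @ elbo_grad

# Toy example (assumed for illustration): for the mean of a Gaussian
# with fixed variance sigma^2, the Fisher information is (1/sigma^2) I,
# so its inverse is sigma^2 * I.
sigma2 = 4.0
mu = np.array([0.0, 0.0])
grad = np.array([0.5, -0.25])   # stochastic ELBO gradient SG(mu)
G = sigma2 * np.eye(2)          # inverse Fisher matrix G(mu)
mu_new = natural_gradient_step(mu, grad, G, lr=0.1)
# mu_new == [0.2, -0.1]
```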
My current idea is to use the SGD optimiser (without momentum) together with Trace_ELBO, read off the gradients after each backward pass, and update the parameter values using the natural gradients instead of the canonical gradients that the SGD optimiser would apply.
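One way to realise this idea without post-hoc correction of the SGD update is a small torch.optim.Optimizer subclass that preconditions each parameter's gradient with its inverse Fisher matrix before the step; it could then be handed to SVI via pyro.optim.PyroOptim. This is only a sketch under my own assumptions (in particular, the fisher_inv_fn callback is hypothetical and would have to map each parameter to its G(param)):

```python
import torch

class NaturalSGD(torch.optim.Optimizer):
    """Plain SGD whose gradients are preconditioned by a per-parameter
    inverse Fisher matrix before the update (a sketch, not tested with Pyro).
    fisher_inv_fn: hypothetical callback, parameter tensor -> G(param)."""

    def __init__(self, params, lr, fisher_inv_fn):
        defaults = dict(lr=lr)
        super().__init__(params, defaults)
        self.fisher_inv_fn = fisher_inv_fn

    @torch.no_grad()
    def step(self, closure=None):
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                G = self.fisher_inv_fn(p)          # inverse Fisher, shape (n, n)
                nat_grad = G @ p.grad.reshape(-1)  # precondition the gradient
                # Pyro minimises loss = -ELBO, so stepping along -nat_grad
                # is natural gradient *ascent* on the ELBO.
                p.add_(nat_grad.reshape(p.shape), alpha=-group["lr"])
```

With this class the rest of the pipeline could stay standard, e.g. something like `pyro.optim.PyroOptim(NaturalSGD, {"lr": 0.1, "fisher_inv_fn": my_fisher_inv})` passed to SVI, though I have not verified that wrapping.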
I would appreciate an expert opinion on this. Do you think this would be a reasonably good implementation, or do you see a more efficient way to implement the algorithm?