Implementing Uncorrected Hamiltonian Annealing / Differentiable Annealed Importance Sampling

I am interested in using UHA/DAIS [1, 2] in Pyro. I noticed that there is an implementation for NumPyro [3]. I would be willing to implement the method in Pyro; however, as I am very new to the Pyro/NumPyro ecosystem, I might need some initial pointers to get started.
I had a very brief look at the NumPyro code today. I think it should be manageable to port large parts of the NumPyro implementation to Pyro. However, I noticed that there is no equivalent to _potential_fn in Pyro. More specifically, I am talking about the function

def log_density(x):
    x_unpack = self._unpack_latent(x)
    with numpyro.handlers.block():
        return -self._potential_fn(x_unpack)

inside the method _sample_latent, which computes the log joint probability of latents and data. I am not sure what the best way to implement this in Pyro would be.

Any help is very much appreciated! Thank you in advance.

[1] Geffner, T. and Domke, J. (2021). MCMC variational inference via uncorrected Hamiltonian annealing. In Advances in Neural Information Processing Systems.
[2] Zhang, G., Hsu, K., Li, J., Finn, C., and Grosse, R. (2021). Differentiable annealed importance sampling and the perils of gradient noise. In Advances in Neural Information Processing Systems.
[3] Automatic Guide Generation — NumPyro documentation

afaik the place to start would be initialize_model

although @fehiepsi might have a better suggestion. note it might be a bit complicated to get all the details right, given the complexity of transforms and the like.

also note that any pyro implementation of uha/dais is expected to be somewhat slow compared to numpyro. basically because there will be overhead from a python for loop, which jax can compile away. also pytorch isn’t great about optimizing gradients. this is one of the reasons we implemented these methods in numpyro.
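
To illustrate where the Python-level loop comes from, here is a rough PyTorch sketch of an annealing loop with leapfrog steps and partial momentum refreshment, loosely modeled on the scheme described in [2]. It is not the NumPyro code: the toy target, the step size `eta`, the refresh coefficient `gamma`, and all function names are assumptions, and it only computes a forward pass of the log weight.

```python
import math
import torch

torch.manual_seed(0)


def log_q(z):
    # Base distribution: standard normal.
    return -0.5 * (z ** 2).sum() - 0.5 * z.numel() * math.log(2 * math.pi)


def log_p(z):
    # Hypothetical unnormalized target: Gaussian with mean 1 and std 0.5.
    return -0.5 * (((z - 1.0) / 0.5) ** 2).sum()


def log_momentum(v):
    # Standard normal momentum density (identity mass matrix).
    return -0.5 * (v ** 2).sum() - 0.5 * v.numel() * math.log(2 * math.pi)


def grad_bridge(z, beta):
    # Gradient of the annealed log density (1 - beta) * log_q + beta * log_p.
    z = z.detach().requires_grad_(True)
    log_pi = (1.0 - beta) * log_q(z) + beta * log_p(z)
    (g,) = torch.autograd.grad(log_pi, z)
    return g


def annealed_log_weight(dim=2, num_steps=8, eta=0.1, gamma=0.9):
    betas = torch.linspace(0.0, 1.0, num_steps + 1)[1:]
    z = torch.randn(dim)
    v = torch.randn(dim)
    log_w = -log_q(z)
    # This Python loop runs once per annealing step; JAX can compile the
    # equivalent loop, while here each step is dispatched by the interpreter.
    for beta in betas:
        v_prev = v
        z_half = z + 0.5 * eta * v_prev
        v_hat = v_prev + eta * grad_bridge(z_half, beta)
        z = z_half + 0.5 * eta * v_hat
        # Uncorrected step: track the change in momentum density instead of
        # applying a Metropolis-Hastings accept/reject.
        log_w = log_w + log_momentum(v_hat) - log_momentum(v_prev)
        # Partial momentum refreshment.
        v = gamma * v_hat + math.sqrt(1.0 - gamma ** 2) * torch.randn(dim)
    return log_w + log_p(z)


log_w = annealed_log_weight()
```

A version suitable for training would have to keep the computation graph through every step (for example via `create_graph=True`), which is part of what makes the PyTorch variant more expensive.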

Thank you for your response! Currently, I do not care about speed, but I have a code base that is written in PyTorch. Are you suggesting that I should use NumPyro directly, since implementing this in Pyro would be too much of a hassle?

@jzenn it’s hard for me to judge the degree to which you might find it a hassle : ) we’d be happy to have a pyro implementation so happy to help review the code and such. just trying to give you a full picture, especially w.r.t. expected speed

I see, thanks for your guidance! I think I will try to create a first draft in the near future. I’ll let you know when I have further questions.