Transforming HMC sampler with autoguides

I realized that significant parameter degeneracies, or large differences in the variances of parameters, have a devastating effect on the performance of NUTS, even for small numbers of parameters. Rescaling the parameters back to O(1) solves that problem. I'm now trying to implement this semi-automatically by using the results of a variational inference step with AutoDiagonalNormal (AutoLaplaceApproximation does not work because grid_sample does not support computing the Hessian). This seems to work fine. However, it seems to me that the approach could be generalized to more expressive guides (e.g., IAF) without too much hassle: the learned bijection could bring the parameter space close to a standard normal before it is fed into HMC. I'm wondering whether this has been implemented successfully in any of the existing probabilistic programming languages, or whether it would be interesting for Pyro.
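
A minimal sketch of the rescaling idea in plain NumPy, assuming the AutoDiagonalNormal fit has produced a per-parameter loc and scale (the numbers below are made up, and the Gaussian potential is a hypothetical stand-in for the model's potential). HMC then runs in whitened coordinates z with x = loc + scale * z. Because this map is affine, its log|det J| is a constant and can be dropped from the potential; only the gradient picks up a factor of scale through the chain rule.

```python
import numpy as np

# Hypothetical target: independent Gaussian with wildly different variances,
# the kind of scale mismatch that makes a single NUTS step size inefficient.
true_mean = np.array([1.0, -3.0])
true_var = np.array([1e4, 1e-4])

def potential_x(x):
    # Negative log density (up to a constant) in the original parameterization.
    return 0.5 * np.sum((x - true_mean) ** 2 / true_var)

def grad_potential_x(x):
    return (x - true_mean) / true_var

# Assumed output of a variational fit with a diagonal normal guide.
loc = np.array([1.0, -3.0])
scale = np.sqrt(true_var)

def to_x(z):
    # Whitening map: z (roughly standard normal) -> model space.
    return loc + scale * z

def potential_z(z):
    # log|det J| = sum(log(scale)) is constant, so it is omitted here.
    return potential_x(to_x(z))

def grad_potential_z(z):
    # Chain rule: dU/dz = scale * dU/dx.
    return scale * grad_potential_x(to_x(z))

# The potential is quadratic, so its Hessian is constant; compare curvature.
hess_x = np.diag(1.0 / true_var)
hess_z = np.diag(scale ** 2 / true_var)
print("condition number in x:", np.linalg.cond(hess_x))  # approximately 1e8
print("condition number in z:", np.linalg.cond(hess_z))  # approximately 1.0
```

In z-space one step size fits every direction, which is exactly what the rescaling buys; a guide with correlations or a flow would replace the elementwise map.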

Hi @cweniger, that is an interesting approach, and it is also one of the main features planned for our next numpyro release (you can find some related discussion in this paper from Google researchers or in this PR). We want to build an easy-to-use framework for that approach. Since you are also interested in this direction, would you like to discuss related technical details or feature requests further in this thread or in GitHub PRs?

Btw, we chose numpyro to explore this direction mainly because it is much faster, and hence suitable for real applications. The same approach can be carried over to Pyro (which already has a mature normalizing flow module). :slight_smile:

Hi @fehiepsi, awesome to see that you guys are already implementing this approach, thanks! My models involve some heavy PyTorch calculations, and the Pyro part is not the bottleneck (individual evaluations can take up to a fraction of a second on GPU, with total runtimes of O(1 day)). Hence I'm not sure that moving to numpyro would help much in this case. EDIT: I should add that part of the model involves generative NNs, written and trained in PyTorch, which might make the migration hard.

Hi @cweniger, I think it is better to use Pyro in your case (although JAX/XLA can be faster than PyTorch on GPU, the framework is in an alpha state, so it might contain bugs that are hard to detect). We have refactored HMC in Pyro to support a general potential_fn, so I believe that porting the approach from numpyro to pyro would be easy. :slight_smile:

Currently, we have created the issues "Handle dynamic support" (pyro-ppl/numpyro#241) and "Handle intermediate values in composed transform" (pyro-ppl/numpyro#242) to track the technical work needed to make inference involving flows more efficient (e.g., avoiding recomputation of intermediate terms in flow transforms). After that, we can start exploring the approach in this topic on some synthetic datasets to see how effective it is in high-dimensional cases. I'll post the main progress and examples here along the way. We'd be very happy to hear your feedback (it would be very valuable for us).
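
To make the "intermediate values" point concrete: for a composition f = f2(f1(x)), the log-Jacobian is log|f1'(x)| + log|f2'(f1(x))|, so computing it separately from the forward pass means evaluating f1(x) a second time unless the intermediate is kept. Below is a toy sketch in plain NumPy (the class names and the forward_and_log_det method are hypothetical, not NumPyro's transform API) in which each transform returns its output together with its log-det, so the composition gets both in a single pass.

```python
import numpy as np

class Affine:
    """Toy elementwise affine bijection y = a + b * x (hypothetical)."""

    def __init__(self, a, b):
        self.a, self.b = a, b

    def forward_and_log_det(self, x):
        y = self.a + self.b * x
        # dy/dx = b for every element, so log|det J| = sum(log|b|).
        log_det = np.sum(np.log(np.abs(self.b)) * np.ones_like(x))
        return y, log_det


class Sinh:
    """Toy elementwise bijection y = sinh(x), with dy/dx = cosh(x)."""

    def forward_and_log_det(self, x):
        return np.sinh(x), np.sum(np.log(np.cosh(x)))


class Compose:
    """Applies transforms in order, passing each intermediate value on once."""

    def __init__(self, parts):
        self.parts = parts

    def forward_and_log_det(self, x):
        total_log_det = 0.0
        for part in self.parts:
            # The intermediate output feeds the next part directly, so the
            # Jacobian terms never require re-running earlier transforms.
            x, log_det = part.forward_and_log_det(x)
            total_log_det += log_det
        return x, total_log_det


flow = Compose([Affine(0.5, 2.0), Sinh(), Affine(-1.0, 0.1)])
y, log_det = flow.forward_and_log_det(np.array([0.1, -0.4, 1.3]))
```

An API that computes the log-det from (x, y) alone has to either recompute or cache the values between parts; returning them together is one simple way around that.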

@cweniger I made a PR to illustrate the process of using a trained autoguide for HMC. I think you can mimic it to do something similar in Pyro. We still need to run more experiments to understand this approach better, so if you have any suggestions, please let us know. Thanks in advance! :slight_smile:
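
The core of the approach is a change of variables: if the trained guide provides a bijection T from a roughly standard-normal z to the model's parameters x = T(z), HMC can run on z with potential U_z(z) = U_x(T(z)) - log|det dT/dz|. The minimal 1-D sketch below (plain NumPy; make_reparam_potential is a hypothetical helper, and exp stands in for a learned flow) checks that the transformed density stays normalized. In PyTorch, torch.distributions transforms expose log_abs_det_jacobian(x, y), which could supply the Jacobian term.

```python
import numpy as np

def make_reparam_potential(potential_x, transform, log_abs_det_jacobian):
    """Potential for HMC in z-space, given x = transform(z).

    potential_x: negative log density of the model in x-space.
    log_abs_det_jacobian: log|det dT/dz| evaluated at z.
    """
    def potential_z(z):
        return potential_x(transform(z)) - log_abs_det_jacobian(z)
    return potential_z

# Target: Exponential(1) on x > 0, i.e. U_x(x) = x.
potential_x = lambda x: x
# Stand-in for a learned bijection: T(z) = exp(z), so log|dT/dz| = z.
potential_z = make_reparam_potential(potential_x, np.exp, lambda z: z)

# exp(-U_z) should integrate to 1 over z, because exp(-U_x) does over x > 0.
z_grid = np.linspace(-20.0, 5.0, 200001)
dz = z_grid[1] - z_grid[0]
mass = np.sum(np.exp(-potential_z(z_grid))) * dz
print(mass)  # close to 1.0
```

Dropping the Jacobian term would silently change the target distribution; for the affine maps discussed in this thread the term is constant, which is why it can be ignored there.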

That is great, thanks a ton! I'll look into this soon; I'm really curious how it is done. I'm particularly interested in whether this method helps to accelerate HMC on very high-dimensional problems (millions of parameters, some of them strongly correlated). I guess in that case the IAF might be hard to train, but maybe the simpler LowRankMultivariateNormal would work as a guide.

@cweniger Using LowRankMVN should work too. You'll need to implement a multivariate AffineTransform, which is not yet available in PyTorch's distributions module. Once you have LowRankMVN's loc and scale_tril tensors, you can use them to transform the parameters in HMC.
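
A minimal NumPy sketch of that multivariate affine transform. It assumes the fitted guide is summarized by loc, cov_factor W and cov_diag D (the parameterization of torch.distributions.LowRankMultivariateNormal), so the covariance is W W^T + diag(D) and scale_tril is its Cholesky factor L. HMC would run on z with x = loc + L z; the log|det| = sum(log L_ii) is constant. The random parameters are placeholders.

```python
import numpy as np

rng = np.random.default_rng(0)
dim, rank = 5, 2

# Placeholder guide parameters (in practice, taken from the fitted guide).
loc = rng.normal(size=dim)
cov_factor = rng.normal(size=(dim, rank))
cov_diag = np.exp(rng.normal(size=dim))

# Covariance of the low-rank-plus-diagonal normal and its Cholesky factor.
cov = cov_factor @ cov_factor.T + np.diag(cov_diag)
scale_tril = np.linalg.cholesky(cov)

def affine_forward(z):
    # x = loc + L z  (whitened coordinates -> model space)
    return loc + scale_tril @ z

def affine_inverse(x):
    # z = L^{-1} (x - loc); a triangular solve would be cheaper for large dim.
    return np.linalg.solve(scale_tril, x - loc)

log_abs_det = np.sum(np.log(np.diag(scale_tril)))

# If the posterior were exactly N(loc, cov), its precision matrix pulled back
# through the transform would be the identity: L^T cov^{-1} L = I.
precision = np.linalg.inv(cov)
print(np.allclose(scale_tril.T @ precision @ scale_tril, np.eye(dim)))  # True
```

For millions of parameters a dense scale_tril would not fit in memory, so a practical version would need a matrix-free way of applying the transform; this sketch does not attempt that.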

Thanks! It looks like development is focused on numpyro right now. Will you keep developing pyro and numpyro in parallel, or should I assume that the future is numpyro? Is it possible to combine a trained pytorch DNN (a trained generative model etc) with numpyro?

We plan to develop new experimental HMC-related features in NumPyro, mostly because the execution speed improvements from XLA result in faster iteration. However, if there is enough interest from the community and the feature seems to be something that will be more generally useful, we will be happy to accept contributions to port it to Pyro (or invest the effort to do so ourselves).

Is it possible to combine a trained pytorch DNN (a trained generative model etc) with numpyro?

I think this is possible by dumping the weights using pickle and rewriting your DNN using JAX's stax module, but it is probably not practical (you won't save any time by doing this). I would suggest using JAX's stax module from the start instead of training your DNN in PyTorch. We are still testing this out in NumPyro, but most of the code should already be on master if you want to play around with it. Porting it to Pyro should also be fairly straightforward.
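
A minimal sketch of the "dump the weights and rewrite the network" route, assuming a small nn.Sequential(nn.Linear(3, 8), nn.Tanh(), nn.Linear(8, 2)). On the PyTorch side the weights could be exported with {k: v.detach().cpu().numpy() for k, v in net.state_dict().items()} and pickled; here that dict is faked with random arrays so the snippet runs without PyTorch. The forward pass is written in NumPy; since jax.numpy mirrors these NumPy calls, swapping the import should give a JAX version, though that is an assumption to verify against the original network.

```python
import pickle
import numpy as np

# Stand-in for the exported state_dict. For the Sequential above the keys are
# "0.weight", "0.bias", "2.weight", "2.bias" (Tanh has no parameters), and
# nn.Linear stores weights with shape (out_features, in_features).
rng = np.random.default_rng(1)
state = {
    "0.weight": rng.normal(size=(8, 3)),
    "0.bias": rng.normal(size=8),
    "2.weight": rng.normal(size=(2, 8)),
    "2.bias": rng.normal(size=2),
}
blob = pickle.dumps(state)  # in practice: pickle.dump(state, some_file)

params = pickle.loads(blob)

def mlp_forward(params, x):
    # Same computation as the PyTorch module: Linear -> Tanh -> Linear,
    # where nn.Linear computes x @ W.T + b.
    h = np.tanh(x @ params["0.weight"].T + params["0.bias"])
    return h @ params["2.weight"].T + params["2.bias"]

x = rng.normal(size=(4, 3))  # batch of 4 inputs
print(mlp_forward(params, x).shape)  # (4, 2)
```

The transpose on each weight is the detail most likely to go wrong when porting; comparing outputs with the PyTorch model on a few inputs is a cheap check.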

Hi @neerajprad. I started exploring HMC in numpyro (which is way faster), and noticed that substitute, which seems to replace the previous condition, does not actually fix parameters of a model that runs through MCMC. I can only fix observed parameters via keyword arguments. Is this expected behavior?

This isn't expected behavior. Could you share your code snippet? This is independent of MCMC, btw, e.g.:

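    import numpyro.distributions as dist
    from numpyro import sample
    from numpyro.handlers import substitute

    d = dist.Delta(0.)

    def model():
        x = sample('x', d)
        # inner substitute: fix 'y' to the value of 'x'
        y = substitute(lambda: sample('y', d), {'y': x})()
        return x + y

    # outer substitute fixes 'x' to 3., so the model returns 3. + 3.
    assert substitute(model, {'x': 3.})() == 6.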

I think what's missing is setting the is_observed flag, which I added to substitute's process_message below.

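    def process_message(self, msg):
        # Patched substitute.process_message (numpyro version used in this
        # thread, with param_map / substitute_fn): substituted sites are
        # additionally marked as observed, so MCMC no longer samples them.
        if self.param_map:
            if msg['name'] in self.param_map:
                msg['value'] = self.param_map[msg['name']]
                msg['is_observed'] = True  # added line
        else:
            value = self.substitute_fn(msg)
            if value is not None:
                msg['value'] = value
                msg['is_observed'] = True  # added line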

After that change the HMC works as expected.


My bad, I misunderstood. Your observation is correct: substitute is just a general convenience effect handler (it works on both param and sample statements) and isn't the same as Pyro's condition handler. In MCMC, we run inference on all unobserved sample sites, and substitute does not mark a site as observed, so a substituted site is still treated as latent and substitute has no effect there. We should probably add a separate condition handler for this.

Ok, thanks for the clarification. A separate condition handler would be very convenient, at least for the physical models that I’m currently looking at where I would like to see what happens when fixing various parameters, etc.

I’ll add it shortly. :slight_smile: