Using the MCMC samplers on a PyTorch module

Hello,

I have used PyTorch to implement a fairly complicated epidemiological model. My idea is to use its differentiability to fit the input parameters using HMC or similar. I am not sure how to make this work with Pyro. So far my attempt looks like this:

def pyro_model(true_data):
    model = Model(...)  # PyTorch model

    # Sample values from parameter priors
    priors = ["parameter_1", "parameter_2"]
    samples = {}
    for key in priors:
        value = pyro.sample(key, ...)
        samples[key] = value

    # Then I set the model parameters to the sampled values
    with torch.no_grad():
        state_dict = model.state_dict()
        for key in priors:
            value = samples[key]
            state_dict[key].copy_(value)

    # Run the model
    results = model()
    # Compare with data
    pyro.sample(..., obs=true_data)

I then use this pyro_model to initialize an MCMC chain as detailed in the docs. The sampling runs but does not behave well: step_size goes to 0 very quickly or the chain gets stuck, which makes me think that some of the internals are not working properly. I have tested that I’m able to recover the original parameters with gradient descent, so I’m confident that the gradients are calculated correctly. Similarly, I’ve successfully computed posteriors using nested sampling algorithms such as pymultinest, but I can’t seem to get the samplers from Pyro to work well.

One thing that I’m not sure I’m handling adequately is the fact that my PyTorch model has a lot of internal randomness, for instance, sampling the number of infected people each day. I do this sampling by importing a distribution from Pyro and calling distribution.rsample(). Could this be the cause of my problems?

Thank you!

yes this is generally wrong. all randomness needs to be registered with a sample statement.

note that just because a problem is amenable to gradient descent doesn’t mean it will be amenable to hmc. inferring a posterior distribution is much more difficult than finding a point estimate

Hi @martinjankowiak, thank you for your quick answer.

yes this is generally wrong. all randomness needs to be registered with a sample statement.

I see. Is there any way to treat the PyTorch model as a black-box stochastic simulator that returns outputs and derivatives with respect to its parameters? Or is the only way to go into the PyTorch source code and replace all the dist.rsample() statements with pyro.sample("name", dist)?

note that just because a problem is amenable to gradient descent doesn’t mean it will be amenable to hmc. inferring a posterior distribution is much more difficult than finding a point estimate

Absolutely, my comment was more about the gradients being computed correctly.

even a black box simulator needs priors defined if you want to do bayesian inference. so yes you need to use pyro.sample everywhere (or possibly PyroSample)

I see, so just to make sure I understand this properly. If I have a model that looks like:

input_parameters -------------> stochastic simulator ----------------> results

and I want to do HMC, it is not enough to sample the input_parameters from priors (using pyro.sample()) and pass them to the simulator; I also have to replace all the sampling inside the simulator (usually done with dist.rsample()) with pyro.sample statements?

What I don’t understand is why this is required, since I could manually write a routine that samples from the priors, runs the model, gets the gradients with autograd, and samples with an HMC algorithm.

you appear to be maintaining a distinction between randomness in input parameters and randomness in internal random variables. however HMC needs a (deterministic) log joint density that encompasses all sources of randomness. whether an internal random variable is more likely to have fluctuated up or down is something that explicitly needs to be reasoned about when computing the posterior
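The point about a deterministic log joint can be illustrated without any library. The sketch below assumes a hypothetical three-variable toy model (parameter theta, internal variable z, observation y, all Normal) chosen purely for illustration. log_joint treats z as an explicit argument, so it returns the same value every time it is called with the same inputs. noisy_target draws z internally, the way a simulator calling rsample() would, so repeated evaluations at the same theta disagree.

```python
import math
import random


def normal_logpdf(x, mean, std):
    """Log density of Normal(mean, std) evaluated at x."""
    return (-0.5 * ((x - mean) / std) ** 2
            - math.log(std)
            - 0.5 * math.log(2.0 * math.pi))


def log_joint(theta, z, y):
    """Deterministic: log p(theta) + log p(z | theta) + log p(y | z)."""
    return (normal_logpdf(theta, 0.0, 1.0)
            + normal_logpdf(z, theta, 1.0)
            + normal_logpdf(y, z, 0.5))


def noisy_target(theta, y, rng):
    """Draws z internally, so the result changes from call to call."""
    z = rng.gauss(theta, 1.0)
    return normal_logpdf(theta, 0.0, 1.0) + normal_logpdf(y, z, 0.5)


y_obs = 1.2
rng = random.Random(0)

# Same inputs, same value: a fixed function that HMC can follow
same = log_joint(0.3, 0.8, y_obs) == log_joint(0.3, 0.8, y_obs)

# Same theta, different values: the target moves between evaluations
first = noisy_target(0.3, y_obs, rng)
second = noisy_target(0.3, y_obs, rng)

print(same, first, second)
```

HMC simulates a trajectory using gradients of one fixed function and then compares the energy at its start and end. If that function changes between evaluations, the comparison is not meaningful, which may be one reason the step size adaptation collapses in the setup described above.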

Ok, I understand, thanks!