I have used PyTorch to implement a fairly complicated epidemiological model. My idea is to use its differentiability to fit the input parameters using HMC or similar. I am not sure how to make this work with Pyro. So far my attempt looks like this:
```python
def pyro_model(true_data):
    model = Model(...)  # PyTorch model
    # Sample values from parameter priors
    priors = ["parameter_1", "parameter_2"]
    samples = {}
    for key in priors:
        value = pyro.sample(key, ...)
        samples[key] = value
    # Then I set the model parameters to the sampled values
    with torch.no_grad():
        state_dict = model.state_dict()
        for key in priors:
            state_dict[key].copy_(samples[key])
    # Run the model
    results = model()
    # Compare with data
    pyro.sample(..., obs=true_data)
```
I then use this `pyro_model` to initialize an MCMC chain as detailed in the docs. The sampling runs, but it does not behave well: `step_size` very quickly goes to 0, or the chain gets stuck, which makes me think that some of the internals are not working properly. I have tested that I am able to recover the original parameters with gradient descent, so I'm confident that the gradients are calculated correctly. Similarly, I've also successfully computed posteriors using nested sampling algorithms such as pymultinest, but I can't seem to get the samplers from Pyro to work well.
One thing that I'm not sure I'm handling adequately is the fact that my PyTorch model has a lot of internal randomness, for instance when sampling the number of infected people each day. I do this sampling by importing a distribution from Pyro and calling `distribution.rsample()`. Could this be the cause of my problems?