Training a Pyro Model that has both Bayesian and Frequentist components

I am trying to train my Pyro neural network model.

To describe my model: I converted the multiple_choice_head portion of the original Frequentist PyTorch model into a Pyro Bayesian network and left the rest of the same PyTorch model in its original Frequentist form. In other words, I assigned prior distributions only to the multiple_choice_head portion of my Pyro model.

I am wondering: when I train my Pyro model as follows, will training update only the parameters of the guide distribution for the multiple_choice_head portion of my model and leave the rest of the Frequentist weights untouched?

```python
import torch
import pyro
from pyro.infer import SVI

optimizer_args = {'lr': 0.00013}
optimizer_3 = torch.optim.Adam
scheduler_args = {'optimizer': optimizer_3,
                  'step_size': 1, 'gamma': 1.5,  # StepLR multiplies lr by gamma every step_size epochs
                  'optim_args': optimizer_args}
scheduler_3 = pyro.optim.StepLR(scheduler_args)

svi_diag_normal = SVI(model, guide_diag_normal, scheduler_3, ...)

train(model, svi_diag_normal...)
```

Thank you,

Hi @h56cho, to ensure that some of your parameters are learned and others are left untouched, you need to make sure that you call pyro.param() only on the weights you want to be learned. For example, in your case make sure you aren't calling pyro.param() on any weights in the multiple_choice_head portion of your Pyro model. You can check which weights are being learned by running a single learning step and then printing pyro.get_param_store().keys().

In my case I actually applied module.PyroSample() to all the weights in the multiple_choice_head portion, since I am only interested in changing the weights of the multiple_choice_head (that is, I applied priors only to the multiple_choice_head portion of my model via module.PyroSample()). I did not apply module.PyroSample() to the remaining (frequentist) parts of my model, because I wanted to avoid changing those frequentist weights.

Would this change the parameters of the guide distribution for the multiple_choice_head while keeping the frequentist weights untouched?

Thank you,

@h56cho it sounds like you've gotten things working by applying PyroSample() only to the desired parameters. To be sure, though, I'd recommend examining pyro.get_param_store().keys() after one or more training steps.
