"Manual" mean-field inference for a subset of parameters

I’m considering a model with log joint $\log p(\theta,\phi,x)$, where $x$ are data and $\theta$ and $\phi$ are parameters. Suppose that we factorise the variational approximation to the posterior as $q(\theta,\phi)=q_\theta(\theta)\,q_\phi(\phi)$, and that we have efficient closed-form mean-field update rules for $q_\theta$ but not for $q_\phi$. Ideally, I’d like to use Pyro to learn $q_\phi$ with SVI (conditioned on $x$ and the current $q_\theta$), but update $q_\theta$ “manually” with its closed-form rule (conditioned on $x$ and the current $q_\phi$). Is that feasible within the current (Num)Pyro framework?
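
For reference, these are the standard mean-field results behind this split. The ELBO for the factorised $q$ is

$$
\mathcal{L}(q_\theta, q_\phi) = \mathbb{E}_{q_\theta(\theta)\,q_\phi(\phi)}\left[\log p(\theta,\phi,x)\right] - \mathbb{E}_{q_\theta}\left[\log q_\theta(\theta)\right] - \mathbb{E}_{q_\phi}\left[\log q_\phi(\phi)\right],
$$

and maximising it over $q_\theta$ with $q_\phi$ held fixed gives the coordinate update

$$
q_\theta^{\text{new}}(\theta) \propto \exp\left(\mathbb{E}_{q_\phi(\phi)}\left[\log p(\theta,\phi,x)\right]\right).
$$

With $q_\theta$ held fixed instead, the middle term of $\mathcal{L}$ is a constant, so an SVI step on $q_\phi$ only needs stochastic gradients of the first and last terms.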

This is probably easier to do in Pyro than in NumPyro, because PyTorch makes things more hackable. Basically, you will want to declare `variational_phi` using `pyro.param` and `variational_theta` as plain torch tensors (so that Pyro doesn’t try to update them), and then interleave `svi.step()` (which will update `variational_phi`) with your custom update rule for `variational_theta`. The details will then depend on whether you’re doing data subsampling, whether $\phi$ and $\theta$ are global or local latent variables, etc.
