I’m considering a model with joint density $p(\theta,\phi,x)$, where $x$ are the data and $\theta$ and $\phi$ are parameters. Suppose that we factorise the variational approximation to the posterior as $q(\theta,\phi)=q_\theta(\theta)q_\phi(\phi)$, and that we have efficient mean-field variational inference update rules for $q_\theta$ but not for $q_\phi$. Ideally, I’d like to use Pyro to learn $q_\phi$ with SVI (conditioned on $x$ and the current $q_\theta$), but update $q_\theta$ “manually”, because I have its update rules in closed form (conditioned on $x$ and the current $q_\phi$). Is that feasible within the current (num)pyro framework?
This is probably easier to do in Pyro, because PyTorch makes things more hackable. Basically, you will want to declare `variational_phi` using `pyro.param` and `variational_theta` as plain torch tensors (so that Pyro doesn’t try to update them), and then interleave `svi.step()` (which will update `variational_phi`) with your custom update rule for `variational_theta`. The details will then depend on whether you’re doing data subsampling, whether `phi` and `theta` are global or local latent variables, etc.