I’m considering a model with log joint $\log p(\theta,\phi,x)$, where $x$ are the data and $\theta$ and $\phi$ are parameters. Suppose that we factorise the variational approximation to the posterior as $q(\theta,\phi)=q_\theta(\theta)\,q_\phi(\phi)$, and that we have efficient mean-field variational inference update rules for $q_\theta$ but not for $q_\phi$. Ideally, I’d like to use pyro to learn $q_\phi$ with SVI (conditioned on $x$ and the current $q_\theta$) but update $q_\theta$ “manually”, because I have its update rules in closed form (conditioned on $x$ and the current $q_\phi$). Is that feasible within the current (num)pyro framework?
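For concreteness, here is the standard coordinate-ascent (CAVI) form that such closed-form mean-field updates usually come from, next to the stochastic-gradient step that SVI would take on $q_\phi$. The symbols $\lambda$ (the parameters of $q_\phi$) and $\rho$ (a step size) are notation introduced here for illustration only.

$$
\mathcal{L}(q_\theta, q_\phi) = \mathbb{E}_{q_\theta(\theta)\,q_\phi(\phi)}\big[\log p(\theta,\phi,x) - \log q_\theta(\theta) - \log q_\phi(\phi)\big]
$$

$$
\log q_\theta^{*}(\theta) = \mathbb{E}_{q_\phi(\phi)}\big[\log p(\theta,\phi,x)\big] + \text{const}
$$

$$
\lambda \leftarrow \lambda + \rho\,\widehat{\nabla_\lambda \mathcal{L}}(q_\theta, q_{\phi;\lambda}), \qquad q_\theta \text{ held fixed}
$$

Alternating these two moves is a block coordinate ascent on the same ELBO $\mathcal{L}$: the first is an exact maximisation over $q_\theta$, the second a noisy gradient step over $\lambda$.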

This is probably easier to do in pyro, because pytorch makes things more hackable. Basically, you will want to declare `variational_phi` using `pyro.param` and `variational_theta` as plain torch tensors (so that pyro doesn’t try to update them), and then interleave `svi.step()` (which will update `variational_phi`) with your custom update rule for `variational_theta`. The details will then depend on whether you’re doing data subsampling, whether `phi` and `theta` are global or local latent variables, etc.
