Hello, any tips/best practices on how to implement a hierarchical VAE with conditioning structure similar to this paper?
I’m confused about how best to implement the typical Pyro model/guide setup, since the prior and posterior are coupled at each level of the hierarchy and the same top-down dependency structure is used in both the inference and generative models, i.e. the prior and posterior generate latent variables in the same order.
I think this means I can’t get away with just having one for loop over the stochastic layers in the guide for the inference model, and another one in the model for the generative model, but maybe I’m mistaken.
something like this is tricky to do (especially efficiently) using vanilla pyro given the default computational flow which is: i) run guide; ii) run model
you’d probably need to follow the kind of approach taken here
Thank you for taking the time to respond! As far as I can tell your code corresponds to a vanilla hierarchical VAE conditioning structure; what I’m looking for is something more along these lines:
I see, so you want to save some computations here. Another way to think of using prior information in the guide is to do all computations in the guide, including the computations for priors, then reuse them in the model. For example,
def model():
    # masked site: contributes no log-prob in the model; its value is
    # supplied by the guide's Delta at the same site name
    x = pyro.sample("x", dist.Normal(0, 1).mask(False))
    ...

def guide():
    x = ...  # any complicated computation
    pyro.sample("x", dist.Delta(x))
In a ladder VAE, x is the loc/scale parameter of the prior of z. Then you can proceed with the template in my last comment. In some sense, a ladder VAE is a type of hierarchical VAE, with improper priors for the loc/scale parameters whose posteriors are Dirac distributions, i.e. deterministic functions of upstream parameters.
Apologies if I’ve misunderstood you, I’m still getting to grips with Pyro. So you can perform the prior-related computations in the guide and save them to be retrieved later in the model? If that’s the case, wouldn’t I run into excessive memory consumption if I have more than a few layers? Also, am I right in saying your approach would be equivalent to saving the loc/scale parameters of the prior like:
def model():
    for i, layer in enumerate(layers):
        loc, scale = self.pz[i]  # reuse prior params cached by the guide
        ...

def guide():
    self.pz = {}
    for i, layer in enumerate(layers):
        p_loc, p_scale = ...  # complicated computation
        self.pz[i] = (p_loc, p_scale)
In which case I would be memory-restricted again. Let me know if you think there’s a more efficient way of doing this in Pyro without changing the “ladder” VAE conditioning structure, or if I’ve misunderstood you altogether; maybe a different code snippet would help.
pyro computes the ELBO in one go rather than greedily, so yes, this may lead to increased memory requirements in some cases. to avoid this you would in effect need to construct your own ELBO objective or the like. depending on what your precise goals are, this may or may not be worth the trouble