Ladder VAE tips

Hello, any tips/best practices on how to implement a hierarchical VAE with a conditioning structure similar to the one in this paper?

I’m confused about how best to implement the typical Pyro model/guide setup, since the prior and posterior are coupled at each level of the hierarchy and the same top-down dependency structure is used in both the inference and the generative model, i.e. the prior and posterior generate the latent variables in the same order.

I think this means I can’t get away with just having a for loop over each stochastic layer in the guide for the inference model, and another one in the model for the generative model, but maybe I’m mistaken.

I really appreciate any feedback on this!

Something like this is tricky to do (especially efficiently) using vanilla Pyro, given the default computational flow, which is: i) run guide; ii) run model.

You’d probably need to follow the kind of approach taken here.

I just used ladder vae recently but maybe I was doing it wrongly. Shouldn’t the code be

import pyro
import pyro.distributions as dist

def model(x):
    z2 = pyro.sample("z2", dist.Normal(0, 1))
    z1 = pyro.sample("z1", dist.Normal(mu1(z2), sigma1(z2)))
    return pyro.sample("x", dist.Normal(mu0(z1), sigma0(z1)), obs=x)

def guide(x):
    # compute x -> mu1_, sigma1_ -> mu2_, sigma2_, ..., then
    z2 = pyro.sample("z2", dist.Normal(mu2_(x), sigma2_(x)))
    z1 = pyro.sample("z1", dist.Normal(mu1_(x), sigma1_(x)))

and use TraceMeanField_ELBO objective?
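
i.e. with a training loop along these lines (just a sketch; num_epochs and data_loader are placeholders, and the mu*/sigma* networks would need to be registered with pyro.module or wrapped in a PyroModule):

from pyro.infer import SVI, TraceMeanField_ELBO
from pyro.optim import Adam

# hypothetical training loop for the model/guide above
svi = SVI(model, guide, Adam({"lr": 1e-3}), loss=TraceMeanField_ELBO())
for epoch in range(num_epochs):
    for x in data_loader:
        loss = svi.step(x)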

Thank you for taking the time to respond! As far as I can tell, your code corresponds to a vanilla hierarchical VAE conditioning structure; what I’m looking for is something more along these lines:

# Deterministic bottom-up pass (encoder extracted features)
a1 = f1(x)  
a2 = f2(a1)

# Stochastic top-down pass (recursively compute both approx. posterior & generative distribution)
e2 ~ Normal(0, 1)

e2_a2 = [e2, a2]  # concatenate
qz2 = Normal(q_mu2(e2_a2), q_sigma2(e2_a2))  # posterior
pz2 = Normal(p_mu2(e2), p_sigma2(e2))  # prior
z2 ~ qz2  # sample when training, z2 ~ pz2 when generating
e1 = e2 + h2(z2)

e1_a1 = [e1, a1]  # concatenate
qz1 = Normal(q_mu1(e1_a1), q_sigma1(e1_a1))  # posterior
pz1 = Normal(p_mu1(e1), p_sigma1(e1))  # prior
z1 ~ qz1  # sample when training, z1 ~ pz1 when generating
x = e1 + h1(z1)

Note the dependencies between different layers.

Please let me know if you have any pointers/examples of how to implement something like this using Pyro, thank you very much!

I see, so you want to save some computations here. Another way to think about using prior information in the guide is to do all the computations in the guide, including those for the priors, and then reuse them in the model. For example,

def model():
    # masked site: contributes nothing to the ELBO, it just receives the value
    # computed in the guide
    x = pyro.sample("x", dist.Normal(0, 1).mask(False))
    ...

def guide():
    x = ...  # any complicated computation
    pyro.sample("x", dist.Delta(x))

In a ladder VAE, x would be the loc/scale parameters of the prior over z. Then you can proceed with the template in my last comment. In some sense, the ladder VAE is a type of hierarchical VAE with improper priors for the loc/scale parameters, whose posteriors are Dirac distributions (i.e. deterministic functions of upstream parameters).
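
To make that concrete, here is a minimal two-layer sketch of this pattern. All of the sizes and networks (f1, f2, q2, h2, p1, q1, dec) are illustrative assumptions rather than anything from the paper, for simplicity the top-level e2 is collapsed into a standard-normal prior on z2, and the masked sites would need to be replaced by a direct computation of the prior parameters if you wanted to run the model on its own for generation:

import torch
import torch.nn as nn
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

x_dim, h_dim, z_dim = 784, 64, 16     # illustrative sizes

f1 = nn.Linear(x_dim, h_dim)          # bottom-up: x -> a1
f2 = nn.Linear(h_dim, h_dim)          # bottom-up: a1 -> a2
q2 = nn.Linear(h_dim, 2 * z_dim)      # posterior params for z2 from a2
h2 = nn.Linear(z_dim, h_dim)          # top-down transform of z2
p1 = nn.Linear(h_dim, 2 * z_dim)      # prior params for z1 from e1
q1 = nn.Linear(2 * h_dim, 2 * z_dim)  # posterior params for z1 from [e1, a1]
dec = nn.Linear(z_dim, x_dim)         # decoder mean for x

def loc_scale(t):
    loc, log_scale = t.chunk(2, dim=-1)
    return loc, log_scale.exp()

def guide(x):
    for name, net in [("f1", f1), ("f2", f2), ("q2", q2),
                      ("h2", h2), ("p1", p1), ("q1", q1)]:
        pyro.module(name, net)
    with pyro.plate("batch", x.shape[0]):
        # deterministic bottom-up pass
        a1 = f1(x)
        a2 = f2(a1)
        # top layer: posterior conditioned on a2
        z2 = pyro.sample("z2", dist.Normal(*loc_scale(q2(a2))).to_event(1))
        # top-down pass: compute the z1 prior parameters here and expose them
        # as Delta sites so the model can reuse them without recomputation
        e1 = h2(z2)
        p_loc1, p_scale1 = loc_scale(p1(e1))
        pyro.sample("p_loc1", dist.Delta(p_loc1).to_event(1))
        pyro.sample("p_scale1", dist.Delta(p_scale1).to_event(1))
        # z1 posterior combines top-down (e1) and bottom-up (a1) information
        e1_a1 = torch.cat([e1, a1], dim=-1)
        pyro.sample("z1", dist.Normal(*loc_scale(q1(e1_a1))).to_event(1))

def model(x):
    pyro.module("dec", dec)
    with pyro.plate("batch", x.shape[0]):
        # masked "priors": zero ELBO contribution, they only carry the values
        # computed in the guide (for generation you would instead compute
        # these parameters from z2 directly)
        p_loc1 = pyro.sample(
            "p_loc1", dist.Normal(x.new_zeros(z_dim), 1.).to_event(1).mask(False))
        p_scale1 = pyro.sample(
            "p_scale1", dist.Normal(x.new_zeros(z_dim), 1.).to_event(1).mask(False))
        pyro.sample("z2", dist.Normal(x.new_zeros(z_dim), 1.).to_event(1))
        z1 = pyro.sample("z1", dist.Normal(p_loc1, p_scale1).to_event(1))
        pyro.sample("x", dist.Normal(dec(z1), 1.).to_event(1), obs=x)

svi = SVI(model, guide, Adam({"lr": 1e-3}), loss=Trace_ELBO())
loss = svi.step(torch.randn(8, x_dim))  # one step on dummy data

The masked sites contribute nothing to the ELBO; their only job is to carry the prior parameters, computed once in the guide during the top-down pass, into the model, where they parameterize the z1 prior.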

I really appreciate the feedback!

Apologies if I’ve misunderstood you, I’m still getting to grips with Pyro: so you can perform the prior-related computations in the guide and save them to be retrieved later in the model? If that’s the case, wouldn’t I run into excessive memory consumption issues if I have more than a few layers? Also, am I right in saying your approach would be equivalent to saving the loc/scale parameters of the prior like:

def model():
    for i, _ in enumerate(layers):
        loc, scale = self.pz[i]
    ...

def guide():
    self.pz = {}
    for i, _ in enumerate(layers):
        p_loc, p_scale = ...  # complicated computation
        self.pz[i] = (p_loc, p_scale)

In which case I would be restricted memory-wise again. Let me know if you think there’s a more efficient way of doing this in Pyro without changing the “ladder” VAE conditioning structure, or if I’ve misunderstood you altogether; maybe a different code snippet would help :slight_smile: .

Many thanks!

Pyro computes the ELBO in one go and does not do so greedily, so yes, this may lead to increased memory requirements in some cases. To avoid this you would in effect need to construct your own ELBO object or the like. Depending on what your precise goals are, this may or may not be worth the trouble.
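
One pragmatic alternative (just a rough sketch, reusing the hypothetical networks and loc_scale helper from the earlier example) is to skip Pyro's ELBO machinery for this particular model and write the ladder ELBO by hand with torch.distributions, so the bottom-up features, prior parameters, and posterior parameters are all computed in a single pass and nothing has to be cached between a separate guide and model run:

import torch
import torch.distributions as td

def ladder_elbo(x):
    # deterministic bottom-up pass
    a1 = f1(x)
    a2 = f2(a1)

    # top layer: posterior q(z2 | x) vs. standard-normal prior
    qz2 = td.Normal(*loc_scale(q2(a2)))
    pz2 = td.Normal(torch.zeros_like(qz2.loc), 1.0)
    z2 = qz2.rsample()
    kl = td.kl_divergence(qz2, pz2).sum(-1)

    # next layer: prior and posterior share the top-down path through e1
    e1 = h2(z2)
    pz1 = td.Normal(*loc_scale(p1(e1)))
    qz1 = td.Normal(*loc_scale(q1(torch.cat([e1, a1], dim=-1))))
    z1 = qz1.rsample()
    kl = kl + td.kl_divergence(qz1, pz1).sum(-1)

    log_lik = td.Normal(dec(z1), 1.0).log_prob(x).sum(-1)
    return (log_lik - kl).mean()

# optimize -ELBO directly with a plain torch optimizer
params = [p for net in (f1, f2, q2, h2, p1, q1, dec) for p in net.parameters()]
optim = torch.optim.Adam(params, lr=1e-3)
optim.zero_grad()
(-ladder_elbo(torch.randn(8, x_dim))).backward()
optim.step()

This keeps the layer-wise KL terms analytic and avoids duplicating the prior/posterior computations, at the cost of giving up Pyro's tracing and inference utilities for this model.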