Conditioning IAF flows on a context input h?

I was looking at the normalizing flows module because I'm hoping to use flows in a project I've been working on. Is it possible to pass additional inputs to the IAF layers before creating a TransformedDistribution? I was re-reading the IAF paper, and it describes passing a context vector h as an input to each flow step, in addition to the previous sample z_{t-1}.
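As a reminder of what conditioning on h means here, the numerically stable IAF update from the paper can be written roughly as below. The notation is paraphrased, and EncoderNN and AutoregressiveNN_t are placeholders for whatever networks are used. The key point is that h is an extra input to every autoregressive network, while the triangular Jacobian (and hence the cheap log-determinant) is unaffected, because h does not depend on z.

```latex
% Encoder produces the base parameters and the context vector h
[\mu_0, \sigma_0, h] = \mathrm{EncoderNN}(x), \qquad
z_0 = \mu_0 + \sigma_0 \odot \epsilon, \quad \epsilon \sim \mathcal{N}(0, I)

% Each flow step t = 1, \dots, T reads both z_{t-1} and h
[m_t, s_t] = \mathrm{AutoregressiveNN}_t(z_{t-1}, h), \qquad
\sigma_t = \mathrm{sigmoid}(s_t)

z_t = \sigma_t \odot z_{t-1} + (1 - \sigma_t) \odot m_t

% m_{t,i} and \sigma_{t,i} depend only on z_{t-1,<i} and h, so the Jacobian is triangular
\log \left| \det \frac{\partial z_t}{\partial z_{t-1}} \right| = \sum_i \log \sigma_{t,i}
```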

From looking at the Pyro code, it doesn't seem to support giving a context h to each flow. Is this actually the case, or did I just miss something? If it ISN'T supported, are there any suggested practices I should be aware of for making an IAF flow that does allow this?
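One common way to build this yourself is to give the MADE-style autoregressive network an extra, unmasked input for h: the masks keep output i from seeing z_j for j >= i, while every hidden unit may read the whole context. Below is a minimal plain-PyTorch sketch of a single step under that assumption; `MaskedLinear` and `ConditionalIAFStep` are hypothetical names for illustration, not Pyro classes, and the one-hidden-layer architecture is a simplification.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MaskedLinear(nn.Linear):
    """nn.Linear whose weight matrix is multiplied by a fixed 0/1 mask."""

    def __init__(self, in_features, out_features, mask):
        super().__init__(in_features, out_features)
        self.register_buffer("mask", mask)

    def forward(self, x):
        return F.linear(x, self.weight * self.mask, self.bias)


class ConditionalIAFStep(nn.Module):
    """Hypothetical single IAF step whose MADE-style net also reads a context h."""

    def __init__(self, dim, context_dim, hidden_dim):
        super().__init__()
        in_deg = torch.arange(1, dim + 1)                          # degrees of z_1..z_D
        hid_deg = torch.arange(hidden_dim) % max(dim - 1, 1) + 1   # degrees in 1..D-1
        # hidden unit k may read z_j only if deg(k) >= j
        mask_in = (hid_deg[:, None] >= in_deg[None, :]).float()
        # output i may read hidden unit k only if i > deg(k): strictly autoregressive
        mask_out = (in_deg[:, None] > hid_deg[None, :]).float()
        self.hidden = MaskedLinear(dim, hidden_dim, mask_in)
        # h is not masked: every hidden unit may depend on the whole context
        self.context = nn.Linear(context_dim, hidden_dim, bias=False)
        self.m_out = MaskedLinear(hidden_dim, dim, mask_out)
        self.s_out = MaskedLinear(hidden_dim, dim, mask_out)

    def forward(self, z, h):
        a = torch.relu(self.hidden(z) + self.context(h))
        m, s = self.m_out(a), self.s_out(a)
        sigma = torch.sigmoid(s)
        z_new = sigma * z + (1 - sigma) * m     # gated update from the IAF paper
        log_det = F.logsigmoid(s).sum(-1)       # sum_i log sigma_i
        return z_new, log_det


torch.manual_seed(0)
step = ConditionalIAFStep(dim=3, context_dim=2, hidden_dim=8)
z1, log_det = step(torch.randn(4, 3), torch.randn(4, 2))  # shapes (4, 3) and (4,)
```

A full flow would stack several such steps, reversing the dimension order between them so every coordinate eventually gets conditioned on the others.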

So maybe something like:

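```python
# Necessary imports (Normal, TransformedDistribution, torch.nn, etc.)
# IAFCondition and .condition(h) are hypothetical: this is the API I'm asking about
iaf = IAFCondition(...)
mu_0, sigma_0, h = encoder(X)  # encoder network also outputs a context vector h
iaf.condition(h)               # fix the context for this flow
base_dist = Normal(mu_0, sigma_0)
flow_dist = TransformedDistribution(base_dist, [iaf])
# the rest of my Pyro goodness
```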
 

Hi @megaloman, thanks for your interest! I’m glad to hear that there are users who would like to see conditional flows, which is something I have been working on recently.

We've just added the ability to represent conditional distributions and transforms (i.e., ones that take an additional context variable as input), as well as a conditional MADE. See:


Currently, there is only an implementation of a conditional PlanarFlow, but a conditional IAF, NAF, etc. will be added to Pyro in the very near future (and I'll do this next for you)! This is now quite easy, since conditional transforms and a conditional MADE are already in Pyro.


That's exciting news, and thank you for the update! To start using it right away, I suppose I should install from source, correct? Looking forward to hearing about further updates for the other flows :smiley:

Yes, that's right! Grab the dev branch and look at

for an example of how to use a conditional planar flow.

And you can follow work on the conditional flows by checking out my PRs: https://github.com/pyro-ppl/pyro/pulls
