Difficulty sampling from conditional distribution (normalizing flows)

I’m experiencing difficulties when sampling from a conditional transformed distribution, aiming to broadly replicate Pyro’s Normalizing Flows tutorial for higher-dimensional distributions.

Specifically, using toy data, I would like to sample from a conditional distribution where the conditioning variable has dimension M=1 and the output variable has dimension N>1 (e.g. N=2 for starters).

Here’s code for a minimal example of what I’m aiming to achieve:

## Import modules
import torch
import pyro.distributions as dist
import pyro.distributions.transforms as T

## Build conditional transformed distribution
dist_base = dist.Normal(torch.zeros(1), torch.ones(1))
x1_transform = T.spline(1)
dist_x1 = dist.TransformedDistribution(dist_base, [x1_transform])
x2_transform = T.conditional_spline(2, context_dim=1)  # note: output dimension N=2, context (conditioning variable) dimension M=1
dist_x2_given_x1 = dist.ConditionalTransformedDistribution(dist_base, [x2_transform])

## Sampling
dist_x2_given_x1.condition(torch.ones(1)).sample()

Sampling from this distribution as shown above raises the following error (apologies for the poor formatting):


IndexError Traceback (most recent call last)
in ()
1 ## Sampling
----> 2 dist_x2_given_x1.condition(torch.ones(1)).sample()

4 frames
/usr/local/lib/python3.7/dist-packages/pyro/distributions/transforms/spline.py in _monotonic_rational_spline(inputs, widths, heights, derivatives, lambdas, inverse, bound, min_bin_width, min_bin_height, min_derivative, min_lambda, eps)
293
294 # Apply the identity function outside the bounding box
--> 295 outputs[outside_interval_mask] = inputs[outside_interval_mask]
296 logabsdet[outside_interval_mask] = 0.0
297 return outputs, logabsdet

IndexError: The shape of the mask [1] at index 0 does not match the shape of the indexed tensor [2] at index 0

On the other hand, computing log probabilities works fine using the following command:

## Scoring
dist_x2_given_x1.condition(torch.ones(1)).log_prob(torch.rand(2))

Perhaps I’m overlooking something fairly obvious, but I could not get around the issue.

Thank you dearly

PS:

Setting M=N (e.g. M=N=2) works fine for both sampling and scoring using the following code:

## Build conditional transformed distribution
dist_base = dist.Normal(torch.zeros(2), torch.ones(2))
x1_transform = T.spline(2)
dist_x1 = dist.TransformedDistribution(dist_base, [x1_transform])
x2_transform = T.conditional_spline(2, context_dim=2)
dist_x2_given_x1 = dist.ConditionalTransformedDistribution(dist_base, [x2_transform])

## Scoring
dist_x2_given_x1.condition(torch.ones(2)).log_prob(torch.rand(2))

## Sampling
dist_x2_given_x1.condition(torch.ones(2)).sample()

I’m not very familiar with the normalizing flow APIs, but aren’t you applying a 2-d transform to a 1-d base distribution? Normalizing flows are bijections and can’t “change” dimension.

cc @stefanwebb

Thank you for your timely response.

My thinking about conditional normalizing flows follows the factorization p(x1,x2,c) = p(x1,x2|c)p(c), where x1 and x2 are two “output” variables (or rather, [x1,x2] is a two-dimensional output variable) and c is the conditioning variable. My understanding is that c conditions the estimation of the bijectors’ parameters via hypernets. Thus a one-dimensional variable can condition a two-dimensional one (for instance), even within the context of normalizing flows (I think).

Perhaps my reasoning is flawed, in which case I look forward to being corrected.
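To make the reasoning above concrete, here is a minimal, dependency-free sketch of a conditional element-wise affine flow, used as a simpler stand-in for a conditional spline. All function names are hypothetical and are not Pyro’s API. A tiny linear “hypernet” maps an M-dimensional context to the N shifts and N log-scales of the bijection, so M and N can differ freely; the bijection itself still maps R^N to R^N, so the base sample it is applied to must also be N-dimensional.

```python
import math
import random


def hypernet(context, weights, biases):
    """Hypothetical linear hypernet: maps an M-dim context to 2*N outputs,
    split into N shifts and N log-scales for an element-wise affine bijection."""
    out = [sum(w * c for w, c in zip(row, context)) + b
           for row, b in zip(weights, biases)]
    n = len(out) // 2
    return out[:n], out[n:]


def forward(z, context, weights, biases):
    """Push a base sample z (must be N-dimensional) through x = z * exp(s) + m."""
    shift, log_scale = hypernet(context, weights, biases)
    if len(z) != len(shift):
        # Analogous to the IndexError in the question: an N-dim transform
        # cannot be applied to a base sample of a different dimension.
        raise ValueError(f"base sample has dim {len(z)}, transform expects {len(shift)}")
    return [zi * math.exp(s) + m for zi, s, m in zip(z, log_scale, shift)]


def inverse(x, context, weights, biases):
    """Invert the bijection: z = (x - m) * exp(-s)."""
    shift, log_scale = hypernet(context, weights, biases)
    return [(xi - m) * math.exp(-s) for xi, s, m in zip(x, log_scale, shift)]


def log_prob(x, context, weights, biases):
    """Change of variables: log p(x | c) = log N(z; 0, I) - sum(log_scale)."""
    _, log_scale = hypernet(context, weights, biases)
    z = inverse(x, context, weights, biases)
    log_base = sum(-0.5 * zi ** 2 - 0.5 * math.log(2 * math.pi) for zi in z)
    return log_base - sum(log_scale)


random.seed(0)
M, N = 1, 2  # context dimension M=1, output dimension N=2 (as in the question)
weights = [[random.gauss(0.0, 0.1) for _ in range(M)] for _ in range(2 * N)]
biases = [0.0] * (2 * N)
context = [1.0]

z = [random.gauss(0.0, 1.0) for _ in range(N)]  # base sample is N-dimensional
x = forward(z, context, weights, biases)
print(len(x), log_prob(x, context, weights, biases))
```

Calling `forward([0.5], context, weights, biases)` with a 1-dimensional base sample raises the `ValueError`, which mirrors the mismatch in the original code, while changing `M` (and the width of each row of `weights`) alone never causes a problem.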

Yes, the dimension of the thing you condition on is unconstrained, but AFAIK your x2_transform assumes dimension 2, while you’re applying it to a 1-dimensional base distribution.