I’m having trouble sampling from a conditional transformed distribution while trying to broadly replicate Pyro’s Normalizing Flows tutorial for higher-dimensional distributions.

Specifically, using toy data, I would like to sample from a conditional distribution where the conditioning variable has dimension M=1 and the output variable has dimension N>1 (e.g. N=2 for starters).

Here’s code for a minimal example of what I’m aiming to achieve:

```python
## Import modules
import torch
import pyro.distributions as dist
import pyro.distributions.transforms as T

## Build conditional transformed distribution
dist_base = dist.Normal(torch.zeros(1), torch.ones(1))
x1_transform = T.spline(1)
dist_x1 = dist.TransformedDistribution(dist_base, [x1_transform])
x2_transform = T.conditional_spline(2, context_dim=1)  # note the specified dimension of output N=2 and context/conditioning variable M=1
dist_x2_given_x1 = dist.ConditionalTransformedDistribution(dist_base, [x2_transform])

## Sampling
dist_x2_given_x1.condition(torch.ones(1)).sample()
```

Sampling from this distribution as shown above raises the following error:

```
IndexError                                Traceback (most recent call last)
in ()
      1 ## Sampling
----> 2 dist_x2_given_x1.condition(torch.ones(1)).sample()

4 frames
/usr/local/lib/python3.7/dist-packages/pyro/distributions/transforms/spline.py in _monotonic_rational_spline(inputs, widths, heights, derivatives, lambdas, inverse, bound, min_bin_width, min_bin_height, min_derivative, min_lambda, eps)
    293 
    294     # Apply the identity function outside the bounding box
--> 295     outputs[outside_interval_mask] = inputs[outside_interval_mask]
    296     logabsdet[outside_interval_mask] = 0.0
    297     return outputs, logabsdet

IndexError: The shape of the mask [1] at index 0 does not match the shape of the indexed tensor [2] at index 0
```
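For what it's worth, the last line of the traceback can be reproduced with plain torch whenever a boolean mask built from a size-1 tensor is used to index a size-2 tensor. This is only a sketch: `inputs`, `outputs` and `bound` below are illustrative stand-ins I chose (assuming the spline's inputs come from my 1-dimensional base while its outputs have broadcast to size 2), not Pyro's actual internals:

```python
import torch

# Illustrative stand-ins (assumption, not Pyro internals):
# `inputs` mimics a draw from the 1-dimensional base distribution,
# `outputs` mimics a result that has broadcast against 2-dimensional spline parameters.
inputs = torch.tensor([0.5])   # shape [1]
outputs = torch.zeros(2)       # shape [2]
bound = 3.0                    # hypothetical bounding-box half-width

# The mask is built from `inputs`, so it has shape [1]
outside_interval_mask = (inputs < -bound) | (inputs > bound)

try:
    # Same statement as line 295 of spline.py in the traceback:
    # indexing the size-2 `outputs` with a size-1 mask fails on shape alone
    outputs[outside_interval_mask] = inputs[outside_interval_mask]
    error_message = None
except IndexError as err:
    error_message = str(err)

print(error_message)
```

The failure depends only on the shapes, not on whether any input actually lies outside the bound.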

On the other hand, computing log probabilities works fine using the following command:

```python
## Scoring
dist_x2_given_x1.condition(torch.ones(1)).log_prob(torch.rand(2))
```
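My unverified guess as to why scoring does not fail: when scoring, it is the 2-dimensional value that goes through the spline (in the inverse direction), and the 1-dimensional standard normal base then simply broadcasts against it. The broadcasting part alone can be checked with `torch.distributions.Normal`, which I'm assuming behaves the same way as the Pyro wrapper here:

```python
import torch
from torch.distributions import Normal

# Base with parameters of shape [1], as in dist_base above
base = Normal(torch.zeros(1), torch.ones(1))

# A 2-dimensional value, as passed to log_prob above
value = torch.rand(2)

# The size-1 parameters broadcast against the size-2 value, so no shape error occurs
log_p = base.log_prob(value)
print(log_p.shape)
```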

Perhaps I’m overlooking something fairly obvious, but I couldn't find a way around the issue.

Thank you dearly

PS:

Setting M=N (e.g. M=N=2) works fine for both sampling and scoring using the following code:

```python
## Build conditional transformed distribution
dist_base = dist.Normal(torch.zeros(2), torch.ones(2))
x1_transform = T.spline(2)
dist_x1 = dist.TransformedDistribution(dist_base, [x1_transform])
x2_transform = T.conditional_spline(2, context_dim=2)
dist_x2_given_x1 = dist.ConditionalTransformedDistribution(dist_base, [x2_transform])

## Scoring
dist_x2_given_x1.condition(torch.ones(2)).log_prob(torch.rand(2))

## Sampling
dist_x2_given_x1.condition(torch.ones(2)).sample()
```