I’m experiencing difficulties when sampling from a conditional transformed distribution, aiming to broadly replicate Pyro’s Normalizing Flows tutorial for higher-dimensional distributions.
Specifically, using toy data, I would like to sample from a conditional distribution where the conditioning variable has dimension M=1 and the output variable has dimension N>1 (e.g. N=2 for starters).
Here’s code for a minimal example of what I’m aiming to achieve:
Import modules
import torch
import pyro.distributions as dist
import pyro.distributions.transforms as T
Build conditional transformed distribution
dist_base = dist.Normal(torch.zeros(1), torch.ones(1))
x1_transform = T.spline(1)
dist_x1 = dist.TransformedDistribution(dist_base, [x1_transform])
x2_transform = T.conditional_spline(2, context_dim=1) # note the specified dimension of output N=2 and context/conditioning variable M=1
dist_x2_given_x1 = dist.ConditionalTransformedDistribution(dist_base, [x2_transform])
Sampling
dist_x2_given_x1.condition(torch.ones(1)).sample()
Sampling from this distribution as shown above raises the following error (apologies for the poor formatting):
IndexError Traceback (most recent call last)
in ()
1 ## Sampling
----> 2 dist_x2_given_x1.condition(torch.ones(1)).sample()
4 frames
/usr/local/lib/python3.7/dist-packages/pyro/distributions/transforms/spline.py in _monotonic_rational_spline(inputs, widths, heights, derivatives, lambdas, inverse, bound, min_bin_width, min_bin_height, min_derivative, min_lambda, eps)
293
294 # Apply the identity function outside the bounding box
--> 295 outputs[outside_interval_mask] = inputs[outside_interval_mask]
296 logabsdet[outside_interval_mask] = 0.0
297 return outputs, logabsdet
IndexError: The shape of the mask [1] at index 0 does not match the shape of the indexed tensor [2] at index 0
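For what it's worth, the indexing failure itself can be reproduced with plain PyTorch. This is only a sketch under my assumption that, inside the spline, the input drawn from the base distribution has shape [1] while the tensor being written to has shape [2]; the variable names and the bound of 3.0 are illustrative and not taken from Pyro's internals:

```python
import torch

# Assumed shapes: a 1-element input and a 2-element output.
inputs = torch.zeros(1)
outputs = torch.zeros(2)

# The boolean mask is built from the inputs, so it also has shape [1].
bound = 3.0  # illustrative value only
outside_interval_mask = (inputs < -bound) | (inputs > bound)

try:
    # Indexing a 2-element tensor with a 1-element boolean mask fails.
    outputs[outside_interval_mask] = inputs[outside_interval_mask]
    error_message = None
except IndexError as err:
    error_message = str(err)

print(error_message)
```

If that assumption holds, the error would be about mismatched tensor sizes rather than about the spline itself.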
On the other hand, computing log probabilities works fine using the following command:
Scoring
dist_x2_given_x1.condition(torch.ones(1)).log_prob(torch.rand(2))
Perhaps I’m overlooking something fairly obvious, but I could not get around the issue.
Thank you dearly
PS:
Setting M=N (e.g. M=N=2) works fine for both sampling and scoring using the following code:
Build conditional transformed distribution
dist_base = dist.Normal(torch.zeros(2), torch.ones(2))
x1_transform = T.spline(2)
dist_x1 = dist.TransformedDistribution(dist_base, [x1_transform])
x2_transform = T.conditional_spline(2, context_dim=2)
dist_x2_given_x1 = dist.ConditionalTransformedDistribution(dist_base, [x2_transform])
Scoring
dist_x2_given_x1.condition(torch.ones(2)).log_prob(torch.rand(2))
Sampling
dist_x2_given_x1.condition(torch.ones(2)).sample()