Creating a custom `TransformedDistribution` that applies transforms element-wise

Hi there,

Thanks for the awesome library and community.

I have a relatively straightforward question.

Here is a snippet from the Normalizing Flow tutorial:

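```python
import torch
import pyro.distributions as dist
import pyro.distributions.transforms as T

base_dist = dist.Normal(torch.zeros(2), torch.ones(2))
spline_transform = T.Spline(2, count_bins=16)
flow_dist = dist.TransformedDistribution(base_dist, [spline_transform])
```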

Samples $(x_1, x_2)$ from `flow_dist` are two-dimensional. I want to transform `flow_dist` into another distribution that applies a different function to each dimension element-wise, for example $(\log x_1, \exp x_2)$.

Thanks in advance and sorry for the silly question.

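This seems like a good use case for `CatTransform` or `StackTransform` from `torch.distributions.transforms`, which apply a separate transform to each slice of a tensor along a given dimension.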