Compose Transforms

In the documentation I could only find ComposeTransformModule for composing several transforms. Consider the following list of transforms:

transforms = [
    Permute(torch.randperm(2, dtype=torch.long)),
    SplineAutoregressive(
        2,
        AutoRegressiveNN(2, [40], param_dims=[8, 8, 8 - 1, 8]),
        order='linear',
        count_bins=8,
    ),
]

the command

import pyro.distributions as dist

transforms = dist.ComposeTransformModule(transforms)

fails with

TypeError: pyro.distributions.transforms.permute.Permute is not a Module subclass

Is there another way to compose such transforms so that I can do the following?

y = transforms(x)         # forward path
z = transforms.inv(y)     # backward path

Thanks for any comments.

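Hi @DerJFP, ComposeTransformModule requires every part to be a torch.nn.Module (it stores the parts in an nn.ModuleList), and Permute is a plain Transform without learnable parameters. You could instead try torch.distributions.transforms.ComposeTransform, which should also be available as a Pyro import: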

import pyro.distributions as dist
help(dist.transforms.ComposeTransform)

You should then be able to write:

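import torch
from pyro.distributions.transforms import ComposeTransform, Permute, SplineAutoregressive
from pyro.nn import AutoRegressiveNN

transform = ComposeTransform([
    Permute(torch.randperm(2, dtype=torch.long)),
    SplineAutoregressive(
        2,
        AutoRegressiveNN(2, [40], param_dims=[8, 8, 8 - 1, 8]),
        order='linear',
        count_bins=8,
    ),
])

x = torch.randn(5, 2)
y = transform(x)         # forward path
z = transform.inv(y)     # inverse path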