In the documentation I only found ComposeTransformModule for composing several transforms. Consider the following list of transforms:
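```python
import torch
from pyro.distributions.transforms import Permute, SplineAutoregressive
from pyro.nn import AutoRegressiveNN

# For order='linear' the spline needs param_dims = [bins, bins, bins - 1, bins]
transforms = [
    Permute(torch.randperm(2, dtype=torch.long)),
    SplineAutoregressive(2, AutoRegressiveNN(2, [40], param_dims=[8, 8, 8 - 1, 8]),
                         order='linear', count_bins=8),
]
```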
Running

```python
import pyro.distributions as dist

transforms = dist.ComposeTransformModule(transforms)
```

raises this error:

```
TypeError: pyro.distributions.transforms.permute.Permute is not a Module subclass
```
Is there another way to compose such transforms so that I can do the following?

```python
y = transforms(x)      # forward pass
z = transforms.inv(y)  # inverse (backward) pass
```
Thanks for any comments.