Is it possible to use `transforms.AffineCoupling`, which takes a `DenseNN` as its hypernetwork, to apply the location-scale transformation to all components of the previous random variable?
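For reference, this is my understanding of what `AffineCoupling` computes, following the description in the Pyro docs. Here d is `split_dim`, D is the input dimension, and NN is the `DenseNN` hypernetwork:

$$
\begin{aligned}
\mathbf{y}_{1:d} &= \mathbf{x}_{1:d} \\
\mathbf{y}_{(d+1):D} &= \boldsymbol{\mu} + \boldsymbol{\sigma} \odot \mathbf{x}_{(d+1):D},
\qquad (\boldsymbol{\mu}, \log\boldsymbol{\sigma}) = \mathrm{NN}(\mathbf{x}_{1:d})
\end{aligned}
$$

With d = 0 every component would be transformed, but the network would receive an empty input, so μ and σ could only be learned constants. That is exactly the learnable mean (and scale) I am after.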
For example, to declare a 2D Gaussian with a learnable mean, I would expect to write:
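```python
import torch
import pyro
from pyro import distributions
from pyro.nn import DenseNN

split_dim = 0
dim = 2
hypernet = DenseNN(split_dim, [10 * dim], [dim - split_dim, dim - split_dim])
affine = distributions.transforms.AffineCoupling(split_dim, hypernet)
td = distributions.TransformedDistribution(distributions.Normal(torch.zeros(dim), torch.ones(dim)), [affine])
pyro.module("affine", affine)  # register the hypernet's parameters with Pyro
```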
However, this only works when `split_dim >= 1`. Am I missing something, or is there another way to do this (other than passing an `nn.Linear` and redefining the base distribution every time)?