I am trying to compose a normalizing flow (SylvesterFlow) with a bijective transform (Tanh) by running the following code:
import torch
import pyro.distributions as distributions

dim = 2
mu = torch.zeros(dim)
log_std = torch.zeros(dim)
# Sylvester flow followed by an elementwise tanh squashing transform
transforms = [distributions.transforms.SylvesterFlow(input_dim=dim),
              distributions.transforms.TanhTransform()]
distributions.TransformedDistribution(distributions.Normal(mu, log_std.exp()), transforms)
but I get the following error:
~/sandbox/pyro_fork/pyro/distributions/transforms/sylvester.py in _inverse(self, y)
127 """
128
--> 129 raise KeyError("SylvesterFlow expected to find key in intermediates cache but didn't")
130
131 def log_abs_det_jacobian(self, x, y):
KeyError: "SylvesterFlow expected to find key in intermediates cache but didn't"
From what I understand, only the Householder and IAF transforms have their inverse implemented, and the missing inverse for SylvesterFlow is what causes this error.
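A minimal sketch of the distinction described above, in plain Python rather than Pyro: `AnalyticTanh` inverts any input through a closed form, while `CacheOnlyFlow` (a hypothetical stand-in for a flow without a coded inverse, such as SylvesterFlow) can only invert outputs it produced itself. The `Tensor` class is also hypothetical, and the sketch assumes the cache is keyed on object identity, so a numerically equal but distinct value misses the cache.

```python
import math


class Tensor:
    """Hypothetical stand-in for a tensor. It defines no __eq__/__hash__,
    so it is hashed and compared by identity (assumed cache behavior)."""

    def __init__(self, value):
        self.value = value


class AnalyticTanh:
    """Toy transform whose inverse has a closed form (atanh): no cache needed."""

    def __call__(self, x):
        return Tensor(math.tanh(x.value))

    def inverse(self, y):
        return Tensor(math.atanh(y.value))


class CacheOnlyFlow:
    """Toy transform with no coded inverse (hypothetical): it can only
    invert output objects that it produced itself."""

    def __init__(self):
        self._cache = {}  # maps each output object y to the input x that produced it

    def __call__(self, x):
        # 2x + sin(x) has derivative 2 + cos(x) > 0, so it is bijective
        y = Tensor(2.0 * x.value + math.sin(x.value))
        self._cache[y] = x
        return y

    def inverse(self, y):
        if y not in self._cache:
            raise KeyError("CacheOnlyFlow expected to find key in cache but didn't")
        return self._cache[y]


flow, tanh = CacheOnlyFlow(), AnalyticTanh()
x = Tensor(0.3)
y = flow(x)
print(flow.inverse(y) is x)  # the exact object from the forward pass is found

same_value = Tensor(y.value)  # numerically equal, but a different object
try:
    flow.inverse(same_value)
except KeyError as err:
    print("KeyError:", err)

print(round(tanh.inverse(tanh(x)).value, 6))  # the analytic inverse accepts any input
```

Only the first lookup succeeds; the second raises the same kind of KeyError as in the traceback, even though the value is identical.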
As far as I know, inversion works by looking up the value y in the cache and, if the corresponding x is found, returning it. Since Tanh is bijective, shouldn't this work anyway, given that the value y = Tanh(NF(x)) will be cached and returned?
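To make the question concrete, here is a toy model (not Pyro's code) of inverting the composition. `TransformedDistribution.log_prob` applies the inverses in reverse order, so Tanh is inverted first and its result is handed to the flow. The sketch assumes identity-keyed caches; the `cache` flag on the hypothetical `Tanh` class loosely mimics the `cache_size` argument of `torch.distributions` transforms, and whether the Tanh transform in a given version actually caches is an assumption to check.

```python
import math


class Tensor:
    """Hypothetical tensor stand-in, hashed and compared by identity."""

    def __init__(self, value):
        self.value = value


class CacheOnlyFlow:
    """Hypothetical flow with no coded inverse: inverts only outputs it produced."""

    def __init__(self):
        self._cache = {}  # output object -> input object

    def __call__(self, x):
        y = Tensor(2.0 * x.value + math.sin(x.value))
        self._cache[y] = x
        return y

    def inverse(self, y):
        if y not in self._cache:
            raise KeyError("expected to find key in cache but didn't")
        return self._cache[y]


class Tanh:
    """Toy tanh with a closed-form inverse; `cache=True` remembers the last (x, y)."""

    def __init__(self, cache=False):
        self.cache = cache
        self._last = None  # (x, y) pair from the most recent forward call

    def __call__(self, x):
        y = Tensor(math.tanh(x.value))
        if self.cache:
            self._last = (x, y)
        return y

    def inverse(self, y):
        if self.cache and self._last is not None and self._last[1] is y:
            return self._last[0]  # hand back the very object the flow produced
        return Tensor(math.atanh(y.value))  # recomputed: a brand-new object


def invert_chain(transforms, y):
    """Invert a composition by applying each inverse in reverse order."""
    for t in reversed(transforms):
        y = t.inverse(y)
    return y


for cache in (False, True):
    chain = [CacheOnlyFlow(), Tanh(cache=cache)]
    x = Tensor(0.3)
    y = x
    for t in chain:
        y = t(y)  # forward pass, as when sampling
    try:
        print(cache, invert_chain(chain, y) is x)
    except KeyError as err:
        print(cache, "KeyError:", err)
```

With `cache=False` the recomputed atanh output is a new object, so the flow's lookup fails; with `cache=True` the tanh inverse returns the exact object the flow produced, and the chain inverts back to `x`.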