I was following the Pyro tutorial on conditional normalizing flows and created the following conditional flow for my application:
```python
import torch
import pyro
import pyro.distributions as dist
import pyro.distributions.transforms as T

# x, z: my training data, tensors of shape (N, 3) defined elsewhere.
# Hypothetical placeholders so the snippet runs on its own:
x = torch.randn(1000, 3)
z = torch.randn(1000, 3)

steps = 2000
base_dist = dist.Normal(torch.zeros(3), torch.ones(3))
# Spline transform whose parameters are predicted from a 3-dim context z
x_transform = T.conditional_spline(3, context_dim=3)
dist_x_given_z = dist.ConditionalTransformedDistribution(base_dist, [x_transform])

modules = torch.nn.ModuleList([x_transform])
optimizer = torch.optim.Adam(modules.parameters(), lr=3e-3)
for step in range(steps):
    optimizer.zero_grad()
    # Maximize the conditional log-likelihood log p(x | z)
    ln_p_x_given_z = dist_x_given_z.condition(z).log_prob(x)
    loss = -(ln_p_x_given_z).mean()
    loss.backward()
    optimizer.step()
    dist_x_given_z.clear_cache()
    if step % 500 == 0:
        print('step: {}, loss: {}'.format(step, loss.item()))
```
The quantity I am now interested in is the following: given `x_0`, `z_0`, and a conditional flow `dist_x_given_z` trained as above, I want to invert the flow to find the value `n_0` in the space of `base_dist` that, when conditioned on `z_0` and passed through `dist_x_given_z`, would have produced `x_0`.
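Written out, with $f(\cdot\,; z)$ as shorthand (notation introduced here) for the spline conditioned on context $z$, the relationship is:

$$
x_0 = f(n_0; z_0) \quad\Longleftrightarrow\quad n_0 = f^{-1}(x_0; z_0),
\qquad
\log p(x_0 \mid z_0) = \log p_{\mathcal{N}}\!\big(f^{-1}(x_0; z_0)\big) - \log\left|\det \frac{\partial f(n; z_0)}{\partial n}\right|_{n = f^{-1}(x_0; z_0)}
$$

The second identity is the change-of-variables formula. PyTorch's `TransformedDistribution.log_prob` evaluates it by applying each transform's inverse to the observation, so the training loop above already computes $f^{-1}(x; z)$ internally on every step.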
I guess this amounts to inverting `x_transform`. How can I do that? (I assume this is possible because the transformation is invertible.)
cc @stefanwebb