Inverting a conditional transform in Pyro

I was following the Pyro tutorial on conditional normalizing flows and created the following conditional flow for my application:

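```python
import torch
import pyro
import pyro.distributions as dist
import pyro.distributions.transforms as T

steps = 2000

# x and z are the training tensors, each of shape (N, 3), defined elsewhere.

base_dist = dist.Normal(torch.zeros(3), torch.ones(3))
x_transform = T.conditional_spline(3, context_dim=3)
dist_x_given_z = dist.ConditionalTransformedDistribution(base_dist, [x_transform])

modules = torch.nn.ModuleList([x_transform])
# The optimizer was not shown in the original post; Adam with lr=3e-3 is an assumption.
optimizer = torch.optim.Adam(modules.parameters(), lr=3e-3)

for step in range(steps):
    optimizer.zero_grad()  # clear gradients left over from the previous step
    ln_p_x_given_z = dist_x_given_z.condition(z).log_prob(x)
    loss = -(ln_p_x_given_z).mean()
    loss.backward()
    optimizer.step()
    dist_x_given_z.clear_cache()  # drop cached transform values between steps

    if step % 500 == 0:
        print('step: {}, loss: {}'.format(step, loss.item()))
```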

The quantity I am now interested in is the following:

Given x_0, z_0 and a conditional flow dist_x_given_z trained as above, I want to invert the flow to find the value n_0 in the base_dist that, when pushed through dist_x_given_z conditioned on z_0, yields x_0.

I guess this amounts to inverting x_transform. How can I do that? I assume it is possible, since the transformation is invertible.
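Writing f(·; z) for x_transform conditioned on z (notation introduced here only for illustration), the quantity asked for, and how it relates to the density the flow was trained on, can be stated as:

```latex
x_0 = f(n_0; z_0) \;\Longleftrightarrow\; n_0 = f^{-1}(x_0; z_0),
\qquad
\log p(x_0 \mid z_0) = \log p_{\mathrm{base}}(n_0) - \log\left|\det \frac{\partial f(n; z_0)}{\partial n}\right|_{n = n_0}
```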

Hi @kirtan, the condition method of x_transform returns a regular torch.distributions.Transform, so you can use the .inv property of the returned transform in the usual way:

```python
x_given_z_transform = x_transform.condition(z_0)
n_0 = x_given_z_transform.inv(x_0)
```

Ok, I see. Thanks, that’s perfect!