Inverting a conditional transform in Pyro

I was following the Pyro tutorial on conditional normalizing flows and created the following conditional flow for my application:

    import torch
    import pyro
    import pyro.distributions as dist
    import pyro.distributions.transforms as T

    # x: observed data, z: context the flow is conditioned on.
    # Both are assumed to be training tensors of shape (N, 3).
    steps = 2000

    base_dist = dist.Normal(torch.zeros(3), torch.ones(3))
    x_transform = T.conditional_spline(3, context_dim=3)
    dist_x_given_z = dist.ConditionalTransformedDistribution(base_dist, [x_transform])

    modules = torch.nn.ModuleList([x_transform])
    optimizer = torch.optim.Adam(modules.parameters(), lr=3e-3)

    for step in range(steps):
        optimizer.zero_grad()
        # log p(x | z) under the conditional flow
        ln_p_x_given_z = dist_x_given_z.condition(z).log_prob(x)
        loss = -(ln_p_x_given_z).mean()
        loss.backward()
        optimizer.step()
        # clear any cached transform values between steps, as in the tutorial
        dist_x_given_z.clear_cache()

        if step % 500 == 0:
            print('step: {}, loss: {}'.format(step, loss.item()))

The quantity I am now interested in is the following:

Given x_0, z_0, and a conditional flow dist_x_given_z trained as above, I want to invert the flow: that is, find the value n_0 in the space of base_dist that, when passed through dist_x_given_z conditioned on z_0, yields x_0.
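
In standard normalizing-flow notation, writing f_z for x_transform conditioned on context z (f_z is notation introduced here, not a Pyro name), the flow maps a base sample n to x = f_z(n). The quantity asked for is then the inverse image of x_0, which is also the point at which log_prob evaluates the base density via the change-of-variables formula:

$$
x_0 = f_{z_0}(n_0) \iff n_0 = f_{z_0}^{-1}(x_0),
\qquad
\log p(x_0 \mid z_0) = \log p_{\text{base}}(n_0) - \log\left|\det \frac{\partial f_{z_0}}{\partial n}(n_0)\right|
$$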

I guess this amounts to inverting x_transform. How can I do that? (I assume this is possible, since the transformation is invertible?)

cc @stefanwebb

Hi @kirtan, x_transform returns a regular torch.distributions.Transform from its condition method, so you can just use the .inv property of this value in the usual way:

    x_given_z_transform = x_transform.condition(z_0)
    n_0 = x_given_z_transform.inv(x_0)

Ok, I see. Thanks, that’s perfect!