# Understanding conditional flows

The code below is from the linked repository. The authors use conditional normalizing flows to infer the exogenous variable (ε). As I am new to both normalizing flows and Pyro, I find it very hard to build an intuition for what is happening in the code.
Could someone explain, mathematically, what each code block is doing?

``````class BaseFlowSEM(BaseSEM):
def __init__(self, num_scales: int = 4, flows_per_scale: int = 2, hidden_channels: int = 256,
use_actnorm: bool = False, **kwargs):
super().__init__(**kwargs)

self.num_scales = num_scales
self.flows_per_scale = flows_per_scale
self.hidden_channels = hidden_channels
self.use_actnorm = use_actnorm

# priors

@pyro_method
def infer(self, **obs):
return self.infer_exogeneous(**obs)

@pyro_method
def counterfactual(self, obs: Mapping, condition: Mapping = None):
_required_data = ('x', 'thickness', 'intensity')
assert set(obs.keys()) == set(_required_data)

exogeneous = self.infer(**obs)

counter = pyro.poutine.do(pyro.poutine.condition(self.sample_scm, data=exogeneous), data=condition)(obs['x'].shape[0])
return {k: v for k, v in zip(('x', 'thickness', 'intensity'), counter)}

@classmethod

parser.add_argument('--num_scales', default=4, type=int, help="number of scales (default: %(default)s)")
parser.add_argument('--flows_per_scale', default=10, type=int, help="number of flows per scale (default: %(default)s)")
parser.add_argument('--hidden_channels', default=256, type=int, help="number of hidden channels in convnet (default: %(default)s)")
parser.add_argument('--use_actnorm', default=False, action='store_true', help="whether to use activation norm (default: %(default)s)")

return parser

class NormalisingFlowsExperiment(BaseCovariateExperiment):
    """Experiment that trains the flow-based SEM by maximum likelihood,
    minimising the summed nats-per-dimension over all observed sites.

    NOTE(review): the pasted snippet lost all indentation; it is
    reconstructed below.
    """

    def __init__(self, hparams, pyro_model: BaseSEM):
        # The image flow operates on (padded) 32x32 single-channel images,
        # so the latent has one dimension per pixel.
        hparams.latent_dim = 32 * 32

        super().__init__(hparams, pyro_model)

    def configure_optimizers(self):
        """Build the optimizer with separate learning rates for the image
        flow (`lr`) and the covariate (PGM) flows (`pgm_lr`)."""
        thickness_params = self.pyro_model.thickness_flow_components.parameters()
        intensity_params = self.pyro_model.intensity_flow_components.parameters()

        x_params = self.pyro_model.trans_modules.parameters()

        # NOTE(review): the optimizer construction wrapping these parameter
        # groups was dropped by the paste; reconstructed here as a single
        # Adam optimizer with per-group learning rates -- confirm against
        # the original repository.
        return torch.optim.Adam([
            {'params': x_params, 'lr': self.hparams.lr},
            {'params': thickness_params, 'lr': self.hparams.pgm_lr},
            {'params': intensity_params, 'lr': self.hparams.pgm_lr},
        ])

    def prepare_data(self):
        super().prepare_data()

        # Reshape the latent sweep into 9 single-channel 32x32 images.
        self.z_range = self.z_range.reshape((9, 1, 32, 32))

    def get_logprobs(self, **obs):
        """Condition the model on the observed sites, then return per-site
        mean log-probabilities and negative log-likelihood in nats per
        dimension (log-prob divided by the number of event dimensions)."""
        _required_data = ('x', 'thickness', 'intensity')
        assert set(obs.keys()) == set(_required_data)

        # Trace the conditioned model so every sample site records its log-prob.
        cond_model = pyro.condition(self.pyro_model.sample, data=obs)
        model_trace = pyro.poutine.trace(cond_model).get_trace(obs['x'].shape[0])
        model_trace.compute_log_prob()

        log_probs = {}
        nats_per_dim = {}
        for name, site in model_trace.nodes.items():
            if site["type"] == "sample" and site["is_observed"]:
                log_probs[name] = site["log_prob"].mean()
                log_prob_shape = site["log_prob"].shape
                value_shape = site["value"].shape
                # Event dims are the trailing dims of the value not covered
                # by the batch-shaped log_prob; their product is the number
                # of dimensions the log-prob is spread over.
                if len(log_prob_shape) < len(value_shape):
                    dims = np.prod(value_shape[len(log_prob_shape):])
                else:
                    dims = 1.
                nats_per_dim[name] = -site["log_prob"].mean() / dims
                # NOTE(review): the paste lost indentation, so the nesting of
                # the nan check under `validate` is reconstructed -- confirm.
                if self.hparams.validate:
                    print(f'at site {name} with dim {dims} and nats: {nats_per_dim[name]} and logprob: {log_probs[name]}')
                    if torch.any(torch.isnan(nats_per_dim[name])):
                        raise ValueError('got nan')

        return log_probs, nats_per_dim

    def prep_batch(self, batch):
        """Convert a raw batch into model inputs: pad images to 32x32,
        dequantise with uniform noise, and add channel dims to covariates."""
        x = batch['image'].float()
        thickness = batch['thickness'].unsqueeze(1).float()
        intensity = batch['intensity'].unsqueeze(1).float()

        # Pad 2 pixels on each side (left, right, top, bottom).
        x = torch.nn.functional.pad(x, (2, 2, 2, 2))
        # Uniform dequantisation so discrete pixel values have a density.
        x += torch.rand_like(x)

        x = x.reshape(-1, 1, 32, 32)

        return {'x': x, 'thickness': thickness, 'intensity': intensity}

    def training_step(self, batch, batch_idx):
        batch = self.prep_batch(batch)

        log_probs, nats_per_dim = self.get_logprobs(**batch)
        # NLL objective: sum of nats-per-dimension over all observed sites.
        loss = torch.stack(tuple(nats_per_dim.values())).sum()

        if torch.isnan(loss):
            raise ValueError('loss went to nan')

        lls = {(f'train/log p({k})'): v for k, v in log_probs.items()}
        nats_per_dim = {('train/' + k + '_nats_per_dim'): v for k, v in nats_per_dim.items()}

        tensorboard_logs = {'train/loss': loss, **nats_per_dim, **lls}

        self.log_dict(tensorboard_logs)

        return loss

    def validation_step(self, batch, batch_idx):
        batch = self.prep_batch(batch)

        log_probs, nats_per_dim = self.get_logprobs(**batch)
        loss = torch.stack(tuple(nats_per_dim.values())).sum()

        # NOTE(review): the log keys use the 'train/' prefix even though this
        # is the validation step -- looks like a copy-paste mislabel; kept
        # as-is to preserve behaviour (downstream aggregation may rely on it).
        lls = {(f'train/log p({k})'): v for k, v in log_probs.items()}
        nats_per_dim = {(k + '_nats_per_dim'): v for k, v in nats_per_dim.items()}

        return {'loss': loss, **lls, **nats_per_dim}

    def test_step(self, batch, batch_idx):
        batch = self.prep_batch(batch)

        log_probs, nats_per_dim = self.get_logprobs(**batch)
        loss = torch.stack(tuple(nats_per_dim.values())).sum()

        # NOTE(review): same 'train/' prefix mislabel as in validation_step;
        # preserved as-is.
        lls = {(f'train/log p({k})'): v for k, v in log_probs.items()}
        nats_per_dim = {(k + '_nats_per_dim'): v for k, v in nats_per_dim.items()}

        return {'loss': loss, **lls, **nats_per_dim}

EXPERIMENT_REGISTRY[NormalisingFlowsExperiment.__name__] = NormalisingFlowsExperiment

class ConditionalFlowSEM(BaseFlowSEM):
    """SEM in which every structural assignment is a (conditional)
    normalizing flow: site = f(eps; parents), with eps drawn from a simple
    Normal base distribution. Because each f is invertible, the exogenous
    noise is recovered exactly as eps = f^{-1}(observation) -- this is the
    abduction step used for counterfactuals.

    NOTE(review): the pasted snippet lost all indentation; it is
    reconstructed below.
    """

    def __init__(self, use_affine_ex: bool = True, **kwargs):
        super().__init__(**kwargs)
        self.use_affine_ex = use_affine_ex

        # decoder parts

        # Flow for modelling t (thickness): a 1-d spline flow, constrained
        # via exp to positive support (thickness is Gamma-like).
        # NOTE(review): `thickness_flow_lognorm` and the base loc/scale
        # buffers are presumably registered in the base class -- not
        # visible in this snippet.
        self.thickness_flow_components = ComposeTransformModule([Spline(1)])
        self.thickness_flow_constraint_transforms = ComposeTransform([self.thickness_flow_lognorm, ExpTransform()])
        self.thickness_flow_transforms = ComposeTransform([self.thickness_flow_components, self.thickness_flow_constraint_transforms])

        # Affine flow for s (intensity): scale and shift are produced by a
        # tiny conditioner network that takes thickness as context.
        intensity_net = DenseNN(1, [1], param_dims=[1, 1], nonlinearity=torch.nn.Identity())
        self.intensity_flow_components = ConditionalAffineTransform(context_nn=intensity_net, event_dim=0)
        self.intensity_flow_constraint_transforms = ComposeTransform([SigmoidTransform(), self.intensity_flow_norm])
        self.intensity_flow_transforms = [self.intensity_flow_components, self.intensity_flow_constraint_transforms]
        # build flow as s_affine_w * t * e_s + b -> depends on t though

        # realnvp or so for x
        self._build_image_flow()

    def _build_image_flow(self):
        """Assemble the multi-scale conditional flow for the image x.

        `trans_modules` collects the learnable transforms (so their
        parameters can be optimised); `x_transforms` is the full ordered
        list including the fixed reshape/permute transforms.
        """
        self.trans_modules = ComposeTransformModule([])

        self.x_transforms = []

        # Pre-processing (e.g. dequantisation/logit) from the base class.
        self.x_transforms += [self._get_preprocess_transforms()]

        c = 1
        for _ in range(self.num_scales):
            # Squeeze: trade spatial resolution for channels (x4 per scale).
            self.x_transforms.append(SqueezeTransform())
            c *= 4

            for _ in range(self.flows_per_scale):
                if self.use_actnorm:
                    actnorm = ActNorm(c)
                    self.trans_modules.append(actnorm)
                    self.x_transforms.append(actnorm)

                # Learned channel permutation (Glow-style 1x1-conv analogue).
                gcp = GeneralizedChannelPermute(channels=c)
                self.trans_modules.append(gcp)
                self.x_transforms.append(gcp)

                # Move channels last so the coupling splits on channels.
                self.x_transforms.append(TransposeTransform(torch.tensor((1, 2, 0))))

                # Affine coupling whose conditioner also receives the
                # (thickness, intensity) context.
                ac = ConditionalAffineCoupling(c // 2, BasicFlowConvNet(c // 2, self.hidden_channels, (c // 2, c // 2), 2))
                self.trans_modules.append(ac)
                self.x_transforms.append(ac)

                # Restore channels-first layout.
                self.x_transforms.append(TransposeTransform(torch.tensor((2, 0, 1))))

            # Final permutation at the end of each scale.
            # NOTE(review): loop nesting reconstructed from the lost
            # indentation -- confirm against the original repository.
            gcp = GeneralizedChannelPermute(channels=c)
            self.trans_modules.append(gcp)
            self.x_transforms.append(gcp)

        # Undo all squeezes: reshape back to a single 1x32x32 image.
        self.x_transforms += [
            ReshapeTransform((4**self.num_scales, 32 // 2**self.num_scales, 32 // 2**self.num_scales), (1, 32, 32))
        ]

        if self.use_affine_ex:
            # Optional global affine in pixel space, conditioned on the
            # 2-dim (thickness, intensity) context; event_dim=3 treats the
            # whole image as one event.
            affine_net = DenseNN(2, [16, 16], param_dims=[1, 1])
            affine_trans = ConditionalAffineTransform(context_nn=affine_net, event_dim=3)

            self.trans_modules.append(affine_trans)
            self.x_transforms.append(affine_trans)

    @pyro_method
    def pgm_model(self):
        """Graphical model over the covariates:
        thickness = f_t(eps_t), intensity = f_s(eps_s; thickness),
        with eps_t, eps_s ~ Normal base distributions."""
        thickness_base_dist = Normal(self.thickness_base_loc, self.thickness_base_scale).to_event(1)
        thickness_dist = TransformedDistribution(thickness_base_dist, self.thickness_flow_transforms)

        thickness = pyro.sample('thickness', thickness_dist)
        # Map thickness back to unconstrained space for use as flow context.
        thickness_ = self.thickness_flow_constraint_transforms.inv(thickness)
        # pseudo call to thickness_flow_transforms to register with pyro
        _ = self.thickness_flow_components

        intensity_base_dist = Normal(self.intensity_base_loc, self.intensity_base_scale).to_event(1)
        # Intensity flow is conditioned on (unconstrained) thickness.
        intensity_dist = ConditionalTransformedDistribution(intensity_base_dist, self.intensity_flow_transforms).condition(thickness_)

        intensity = pyro.sample('intensity', intensity_dist)
        # pseudo call to w_flow_transforms to register with pyro
        _ = self.intensity_flow_components

        return thickness, intensity

    @pyro_method
    def model(self):
        """Full generative model: sample the covariates, then sample the
        image x from the conditional flow given (thickness, intensity)."""
        thickness, intensity = self.pgm_model()

        thickness_ = self.thickness_flow_constraint_transforms.inv(thickness)
        intensity_ = self.intensity_flow_norm.inv(intensity)

        context = torch.cat([thickness_, intensity_], 1)

        x_base_dist = Normal(self.x_base_loc, self.x_base_scale).to_event(3)
        # Condition the transform list on the context, then take .inv so
        # that sampling maps base noise -> image.
        cond_x_transforms = ComposeTransform(ConditionalTransformedDistribution(x_base_dist, self.x_transforms).condition(context).transforms).inv
        cond_x_dist = TransformedDistribution(x_base_dist, cond_x_transforms)

        x = pyro.sample('x', cond_x_dist)

        return x, thickness, intensity

    @pyro_method
    def infer_thickness_base(self, thickness):
        # Abduction for thickness: eps_t = f_t^{-1}(thickness).
        return self.thickness_flow_transforms.inv(thickness)

    @pyro_method
    def infer_intensity_base(self, thickness, intensity):
        # Abduction for intensity: eps_s = f_s^{-1}(intensity; thickness).
        intensity_base_dist = Normal(self.intensity_base_loc, self.intensity_base_scale)

        thickness_ = self.thickness_flow_constraint_transforms.inv(thickness)
        cond_intensity_transforms = ComposeTransform(
            ConditionalTransformedDistribution(intensity_base_dist, self.intensity_flow_transforms).condition(thickness_).transforms)
        return cond_intensity_transforms.inv(intensity)

    @pyro_method
    def infer_x_base(self, thickness, intensity, x):
        # Abduction for the image: eps_x = f_x^{-1}(x; thickness, intensity).
        # Note: unlike `model`, no trailing .inv here -- the conditioned
        # transforms are applied forward, which here is the image -> noise
        # direction (the sampling direction in `model` was their inverse).
        x_base_dist = Normal(self.x_base_loc, self.x_base_scale)

        thickness_ = self.thickness_flow_constraint_transforms.inv(thickness)
        intensity_ = self.intensity_flow_norm.inv(intensity)

        context = torch.cat([thickness_, intensity_], 1)
        cond_x_transforms = ComposeTransform(ConditionalTransformedDistribution(x_base_dist, self.x_transforms).condition(context).transforms)
        return cond_x_transforms(x)
@classmethod