The code below is from the repository Link. The authors claim to use conditional normalizing flows to infer the exogenous variables (ε). I am new to both normalizing flows and Pyro, and I find it hard to build an intuition for what is happening in this code.
Can someone explain, mathematically, what each code block is doing?
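My rough understanding so far: a normalizing flow f maps a simple base ("exogenous") variable ε ~ N(0, 1) to the observed variable, so sampling is x = f(ε) and "inferring the exogenous variable" means applying the inverse, ε = f⁻¹(x). Here is a minimal, self-contained toy sketch of that idea (my own code, not from the repository):

```python
import torch
import pyro.distributions as dist
from pyro.distributions.transforms import Spline

# toy 1-D flow: eps ~ N(0, 1), x = f(eps) with f a learnable monotonic spline
flow = Spline(1)
base = dist.Normal(torch.zeros(1), torch.ones(1))
p_x = dist.TransformedDistribution(base, [flow])

x = p_x.sample(torch.Size([5]))   # generation: eps -> x = f(eps)
eps = flow.inv(x)                 # "abduction": x -> eps = f^-1(x)
log_px = p_x.log_prob(x)          # change of variables: log p(x) = log N(eps) - log|det df/deps|
```

What I cannot follow is how this extends to the conditional, multi-variable case below. The repository code (imports omitted) is: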
```python
class BaseFlowSEM(BaseSEM):
    def __init__(self, num_scales: int = 4, flows_per_scale: int = 2, hidden_channels: int = 256,
                 use_actnorm: bool = False, **kwargs):
        super().__init__(**kwargs)

        self.num_scales = num_scales
        self.flows_per_scale = flows_per_scale
        self.hidden_channels = hidden_channels
        self.use_actnorm = use_actnorm

        # priors
        self.register_buffer('thickness_base_loc', torch.zeros([1, ], requires_grad=False))
        self.register_buffer('thickness_base_scale', torch.ones([1, ], requires_grad=False))

        self.register_buffer('intensity_base_loc', torch.zeros([1, ], requires_grad=False))
        self.register_buffer('intensity_base_scale', torch.ones([1, ], requires_grad=False))

        self.register_buffer('x_base_loc', torch.zeros([1, 32, 32], requires_grad=False))
        self.register_buffer('x_base_scale', torch.ones([1, 32, 32], requires_grad=False))
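        # (As I read it, these buffers fix the base/exogenous distributions:
        # eps_thickness ~ N(0, 1), eps_intensity ~ N(0, 1), and eps_x ~ N(0, I) over a 1x32x32 image;
        # the flows defined below transform these base variables into the observed variables.)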

    @pyro_method
    def infer(self, **obs):
        return self.infer_exogeneous(**obs)

    @pyro_method
    def counterfactual(self, obs: Mapping, condition: Mapping = None):
        _required_data = ('x', 'thickness', 'intensity')
        assert set(obs.keys()) == set(_required_data)

        exogeneous = self.infer(**obs)

        counter = pyro.poutine.do(pyro.poutine.condition(self.sample_scm, data=exogeneous), data=condition)(obs['x'].shape[0])
        return {k: v for k, v in zip(('x', 'thickness', 'intensity'), counter)}
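        # (My reading: this is the usual counterfactual recipe -- abduction: infer the exogenous
        # noise from the observation; action: pyro.poutine.do applies the intervention given in
        # `condition`; prediction: re-run sample_scm with the inferred noise under that intervention.)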

    @classmethod
    def add_arguments(cls, parser):
        parser = super().add_arguments(parser)

        parser.add_argument('--num_scales', default=4, type=int, help="number of scales (default: %(default)s)")
        parser.add_argument('--flows_per_scale', default=10, type=int, help="number of flows per scale (default: %(default)s)")
        parser.add_argument('--hidden_channels', default=256, type=int, help="number of hidden channels in convnet (default: %(default)s)")
        parser.add_argument('--use_actnorm', default=False, action='store_true', help="whether to use activation norm (default: %(default)s)")

        return parser


class NormalisingFlowsExperiment(BaseCovariateExperiment):
    def __init__(self, hparams, pyro_model: BaseSEM):
        hparams.latent_dim = 32 * 32

        super().__init__(hparams, pyro_model)

    def configure_optimizers(self):
        thickness_params = self.pyro_model.thickness_flow_components.parameters()
        intensity_params = self.pyro_model.intensity_flow_components.parameters()
        x_params = self.pyro_model.trans_modules.parameters()

        return torch.optim.Adam([
            {'params': x_params, 'lr': self.hparams.lr},
            {'params': thickness_params, 'lr': self.hparams.pgm_lr},
            {'params': intensity_params, 'lr': self.hparams.pgm_lr},
        ], lr=self.hparams.lr, eps=1e-5, amsgrad=self.hparams.use_amsgrad, weight_decay=self.hparams.l2)

    def prepare_data(self):
        super().prepare_data()
        self.z_range = self.z_range.reshape((9, 1, 32, 32))

    def get_logprobs(self, **obs):
        _required_data = ('x', 'thickness', 'intensity')
        assert set(obs.keys()) == set(_required_data)

        cond_model = pyro.condition(self.pyro_model.sample, data=obs)
        model_trace = pyro.poutine.trace(cond_model).get_trace(obs['x'].shape[0])
        model_trace.compute_log_prob()

        log_probs = {}
        nats_per_dim = {}
        for name, site in model_trace.nodes.items():
            if site["type"] == "sample" and site["is_observed"]:
                log_probs[name] = site["log_prob"].mean()
                log_prob_shape = site["log_prob"].shape
                value_shape = site["value"].shape
                if len(log_prob_shape) < len(value_shape):
                    dims = np.prod(value_shape[len(log_prob_shape):])
                else:
                    dims = 1.
                nats_per_dim[name] = -site["log_prob"].mean() / dims
                if self.hparams.validate:
                    print(f'at site {name} with dim {dims} and nats: {nats_per_dim[name]} and logprob: {log_probs[name]}')
                    if torch.any(torch.isnan(nats_per_dim[name])):
                        raise ValueError('got nan')

        return log_probs, nats_per_dim
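        # (My reading: conditioning the generative model on the observed data makes every sample site
        # observed, so site["log_prob"] is the exact flow log-likelihood via the change-of-variables
        # formula, factorised as log p(thickness) + log p(intensity | thickness) + log p(x | thickness, intensity).
        # nats_per_dim divides the negative log-likelihood by the number of event dimensions,
        # e.g. 32*32 = 1024 for the image, and the training loss below is the sum of these terms.)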

    def prep_batch(self, batch):
        x = batch['image'].float()
        thickness = batch['thickness'].unsqueeze(1).float()
        intensity = batch['intensity'].unsqueeze(1).float()

        x = torch.nn.functional.pad(x, (2, 2, 2, 2))
        x += torch.rand_like(x)

        x = x.reshape(-1, 1, 32, 32)

        return {'x': x, 'thickness': thickness, 'intensity': intensity}
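        # (I assume the raw images are 28x28 Morpho-MNIST digits: padding by 2 on each side gives
        # 32x32, and adding uniform noise in [0, 1) is the usual dequantization so that a continuous
        # density can be fitted to discrete pixel values.)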

    def training_step(self, batch, batch_idx):
        batch = self.prep_batch(batch)

        log_probs, nats_per_dim = self.get_logprobs(**batch)
        loss = torch.stack(tuple(nats_per_dim.values())).sum()

        if torch.isnan(loss):
            self.logger.experiment.add_text('nan', f'nand at {self.current_epoch}')
            raise ValueError('loss went to nan')

        lls = {(f'train/log p({k})'): v for k, v in log_probs.items()}
        nats_per_dim = {('train/' + k + '_nats_per_dim'): v for k, v in nats_per_dim.items()}

        tensorboard_logs = {'train/loss': loss, **nats_per_dim, **lls}
        self.log_dict(tensorboard_logs)

        return loss

    def validation_step(self, batch, batch_idx):
        batch = self.prep_batch(batch)

        log_probs, nats_per_dim = self.get_logprobs(**batch)
        loss = torch.stack(tuple(nats_per_dim.values())).sum()

        lls = {(f'train/log p({k})'): v for k, v in log_probs.items()}
        nats_per_dim = {(k + '_nats_per_dim'): v for k, v in nats_per_dim.items()}

        return {'loss': loss, **lls, **nats_per_dim}

    def test_step(self, batch, batch_idx):
        batch = self.prep_batch(batch)

        log_probs, nats_per_dim = self.get_logprobs(**batch)
        loss = torch.stack(tuple(nats_per_dim.values())).sum()

        lls = {(f'train/log p({k})'): v for k, v in log_probs.items()}
        nats_per_dim = {(k + '_nats_per_dim'): v for k, v in nats_per_dim.items()}

        return {'loss': loss, **lls, **nats_per_dim}
EXPERIMENT_REGISTRY[NormalisingFlowsExperiment.__name__] = NormalisingFlowsExperiment


class ConditionalFlowSEM(BaseFlowSEM):
    def __init__(self, use_affine_ex: bool = True, **kwargs):
        super().__init__(**kwargs)
        self.use_affine_ex = use_affine_ex

        # decoder parts

        # Flow for modelling t Gamma
        self.thickness_flow_components = ComposeTransformModule([Spline(1)])
        self.thickness_flow_constraint_transforms = ComposeTransform([self.thickness_flow_lognorm, ExpTransform()])
        self.thickness_flow_transforms = ComposeTransform([self.thickness_flow_components, self.thickness_flow_constraint_transforms])

        # affine flow for s normal
        intensity_net = DenseNN(1, [1], param_dims=[1, 1], nonlinearity=torch.nn.Identity())
        self.intensity_flow_components = ConditionalAffineTransform(context_nn=intensity_net, event_dim=0)
        self.intensity_flow_constraint_transforms = ComposeTransform([SigmoidTransform(), self.intensity_flow_norm])
        self.intensity_flow_transforms = [self.intensity_flow_components, self.intensity_flow_constraint_transforms]

        # build flow as s_affine_w * t * e_s + b -> depends on t though
        # realnvp or so for x
        self._build_image_flow()

    def _build_image_flow(self):
        self.trans_modules = ComposeTransformModule([])

        self.x_transforms = []

        self.x_transforms += [self._get_preprocess_transforms()]

        c = 1
        for _ in range(self.num_scales):
            self.x_transforms.append(SqueezeTransform())
            c *= 4

            for _ in range(self.flows_per_scale):
                if self.use_actnorm:
                    actnorm = ActNorm(c)
                    self.trans_modules.append(actnorm)
                    self.x_transforms.append(actnorm)

                gcp = GeneralizedChannelPermute(channels=c)
                self.trans_modules.append(gcp)
                self.x_transforms.append(gcp)

                self.x_transforms.append(TransposeTransform(torch.tensor((1, 2, 0))))

                ac = ConditionalAffineCoupling(c // 2, BasicFlowConvNet(c // 2, self.hidden_channels, (c // 2, c // 2), 2))
                self.trans_modules.append(ac)
                self.x_transforms.append(ac)

                self.x_transforms.append(TransposeTransform(torch.tensor((2, 0, 1))))

            gcp = GeneralizedChannelPermute(channels=c)
            self.trans_modules.append(gcp)
            self.x_transforms.append(gcp)

        self.x_transforms += [
            ReshapeTransform((4**self.num_scales, 32 // 2**self.num_scales, 32 // 2**self.num_scales), (1, 32, 32))
        ]

        if self.use_affine_ex:
            affine_net = DenseNN(2, [16, 16], param_dims=[1, 1])
            affine_trans = ConditionalAffineTransform(context_nn=affine_net, event_dim=3)

            self.trans_modules.append(affine_trans)
            self.x_transforms.append(affine_trans)
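        # (My reading: this builds a multi-scale, Glow/RealNVP-style flow for the image. Per scale,
        # a SqueezeTransform trades spatial size for channels (c *= 4); each step is an optional
        # ActNorm, a channel permutation (like Glow's 1x1 convolution), and a ConditionalAffineCoupling
        # whose scale/shift convnet also receives (thickness, intensity) as context; the
        # TransposeTransforms shuffle the channel axis to last and back around the coupling.
        # Composed together, this is an invertible map between x and eps_x given the covariates.)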

    @pyro_method
    def pgm_model(self):
        thickness_base_dist = Normal(self.thickness_base_loc, self.thickness_base_scale).to_event(1)
        thickness_dist = TransformedDistribution(thickness_base_dist, self.thickness_flow_transforms)

        thickness = pyro.sample('thickness', thickness_dist)
        thickness_ = self.thickness_flow_constraint_transforms.inv(thickness)
        # pseudo call to thickness_flow_transforms to register with pyro
        _ = self.thickness_flow_components

        intensity_base_dist = Normal(self.intensity_base_loc, self.intensity_base_scale).to_event(1)
        intensity_dist = ConditionalTransformedDistribution(intensity_base_dist, self.intensity_flow_transforms).condition(thickness_)

        intensity = pyro.sample('intensity', intensity_dist)
        # pseudo call to w_flow_transforms to register with pyro
        _ = self.intensity_flow_components

        return thickness, intensity
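        # (So, mathematically: thickness = exp( lognorm( Spline(eps_t) ) ) with eps_t ~ N(0, 1), sampled
        # through a TransformedDistribution; intensity is eps_i ~ N(0, 1) pushed through an affine
        # transform whose shift and scale come from the small DenseNN applied to the normalised
        # thickness, then a sigmoid and a rescaling (intensity_flow_norm) -- a
        # ConditionalTransformedDistribution. This encodes the causal graph thickness -> intensity.
        # The "pseudo calls" presumably just touch the trainable modules so their parameters are
        # registered with pyro when the model is traced.)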

    @pyro_method
    def model(self):
        thickness, intensity = self.pgm_model()

        thickness_ = self.thickness_flow_constraint_transforms.inv(thickness)
        intensity_ = self.intensity_flow_norm.inv(intensity)

        context = torch.cat([thickness_, intensity_], 1)

        x_base_dist = Normal(self.x_base_loc, self.x_base_scale).to_event(3)
        cond_x_transforms = ComposeTransform(ConditionalTransformedDistribution(x_base_dist, self.x_transforms).condition(context).transforms).inv
        cond_x_dist = TransformedDistribution(x_base_dist, cond_x_transforms)

        x = pyro.sample('x', cond_x_dist)

        return x, thickness, intensity
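        # (My reading: the image likelihood is another TransformedDistribution, but its transforms are
        # conditioned on the context (normalised thickness, intensity). Note the .inv: the x_transforms
        # built above map x -> eps_x (the normalising direction), so the composed transform is inverted
        # here to get the generative direction eps_x -> x that TransformedDistribution expects.)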

    @pyro_method
    def infer_thickness_base(self, thickness):
        return self.thickness_flow_transforms.inv(thickness)

    @pyro_method
    def infer_intensity_base(self, thickness, intensity):
        intensity_base_dist = Normal(self.intensity_base_loc, self.intensity_base_scale)

        thickness_ = self.thickness_flow_constraint_transforms.inv(thickness)
        cond_intensity_transforms = ComposeTransform(
            ConditionalTransformedDistribution(intensity_base_dist, self.intensity_flow_transforms).condition(thickness_).transforms)
        return cond_intensity_transforms.inv(intensity)

    @pyro_method
    def infer_x_base(self, thickness, intensity, x):
        x_base_dist = Normal(self.x_base_loc, self.x_base_scale)

        thickness_ = self.thickness_flow_constraint_transforms.inv(thickness)
        intensity_ = self.intensity_flow_norm.inv(intensity)

        context = torch.cat([thickness_, intensity_], 1)
        cond_x_transforms = ComposeTransform(ConditionalTransformedDistribution(x_base_dist, self.x_transforms).condition(context).transforms)
        return cond_x_transforms(x)
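        # (My reading: these infer_*_base methods apply the observed-to-base (normalising) direction
        # of each flow, given the observed parents, recovering eps_thickness, eps_intensity and eps_x
        # -- i.e. the abduction step needed for counterfactuals.)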

    @classmethod
    def add_arguments(cls, parser):
        parser = super().add_arguments(parser)

        parser.add_argument(
            '--use_affine_ex', default=False, action='store_true', help="whether to use conditional affine transformation on e_x (default: %(default)s)")

        return parser
```