Question about the implementation of InverseAutoregressiveFlow

The current log_abs_det_jacobian of IAF directly returns log_scale in this line. Shouldn’t it be summed over the data dimension, i.e., log_scale.sum(dim=1, keepdim=True)?

Hi Xiucheng, that’s a great question! It would seem reasonable that IAF should sum out the last dimension as you suggest. However, changing the code to log_scale.sum(dim=1, keepdim=True) would result in the IAF density being miscalculated. Essentially, the reason log_abs_det_jacobian doesn’t sum over the final dimension is a bit of a hack, required because IAF is typically built on top of a univariate base distribution. Let me elaborate.

You first create the base distribution with, e.g.,

```python
base_dist = dist.Normal(torch.zeros(10), torch.ones(10))
```

Then the IAF distribution is created with

```python
iaf = dist.InverseAutoregressiveFlow(AutoRegressiveNN(10, [40]))
iaf_dist = dist.TransformedDistribution(base_dist, [iaf])
```
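For completeness, here is roughly what the full setup looks like with imports. Module paths such as pyro.nn.AutoRegressiveNN have moved around between Pyro releases, so treat this as a sketch rather than version-exact code:

```python
import torch
import pyro.distributions as dist
from pyro.nn import AutoRegressiveNN  # exact import path may differ across Pyro versions

# Univariate base distribution with a batch of 10 independent normals
base_dist = dist.Normal(torch.zeros(10), torch.ones(10))

# IAF transform parameterized by an autoregressive network, wrapped in a TransformedDistribution
iaf = dist.InverseAutoregressiveFlow(AutoRegressiveNN(10, [40]))
iaf_dist = dist.TransformedDistribution(base_dist, [iaf])
```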

Let’s examine how the density is calculated. Looking at the PyTorch code for TransformedDistribution, the normalizing flow log-density is calculated as:

```python
x = transform.inv(y)
log_prob = log_prob - _sum_rightmost(transform.log_abs_det_jacobian(x, y),
                                     event_dim - transform.event_dim)
# etc...
log_prob = log_prob + _sum_rightmost(self.base_dist.log_prob(y),
                                     event_dim - len(self.base_dist.event_shape))
```
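As an aside, _sum_rightmost(value, dim) simply sums out the dim rightmost dimensions of value, and is a no-op when dim == 0, which is exactly our situation here since event_dim = 0 and transform.event_dim = 0. A rough sketch of the helper’s behaviour (an approximation of torch.distributions.utils._sum_rightmost, not the exact source):

```python
def _sum_rightmost(value, dim):
    # Sum out the `dim` rightmost dimensions of `value`; no-op when dim == 0.
    if dim == 0:
        return value
    return value.reshape(value.shape[:-dim] + (-1,)).sum(-1)
```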

In our example, transform.log_abs_det_jacobian(x, y) has shape [10] and self.base_dist.log_prob(y) has shape [10], so log_prob is a vector whose i-th element is log(p(x_i)) - log(sigma_i), and summing the elements of this vector gives the correct log-density for IAF (see Eq. (11) of the IAF paper).

If we were to change the code, we would instead have transform.log_abs_det_jacobian(x, y) with shape [], and log_prob would be a vector whose i-th element is log(p(x_i)) - sum_j(log(sigma_j)), i.e. the full Jacobian term broadcast against every element. Summing this gives the wrong result (you can verify the difference numerically)!
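Here is a rough numerical check, using a plain elementwise affine transform as a stand-in for IAF (the IAF scale is data-dependent, but the bookkeeping of the Jacobian terms is the same):

```python
import torch

torch.manual_seed(0)
D = 10
base = torch.distributions.Normal(torch.zeros(D), torch.ones(D))  # univariate base, batch of D
log_scale = 0.1 * torch.randn(D)   # stand-in for IAF's data-dependent log(sigma_i)
scale = log_scale.exp()

y = scale * base.sample()          # "forward" transform of a base sample (shift omitted for brevity)
x = y / scale                      # inverse transform back to the base space

# Current behaviour: per-dimension Jacobian terms, summed once at the end.
correct = (base.log_prob(x) - log_scale).sum()

# With the suggested change: the pre-summed Jacobian broadcasts against all D elements,
# so sum(log sigma) ends up counted D times.
wrong = (base.log_prob(x) - log_scale.sum()).sum()

print(correct.item(), wrong.item())   # they differ by (D - 1) * log_scale.sum()
```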

The problem here really is that IAF should be taking a multivariate base distribution, such as a multivariate normal with diagonal covariance. This requires putting a Pyro wrapper on LowRankMultivariateNormal, rewriting how IAF is used in the DMM tutorial example, and so on, but I think it’s a good idea. So I’m going to put together a proposal for this improvement and submit a pull request in the next few days.
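To sketch the idea (this is not the actual Pyro change, just an illustration using torch.distributions.Independent as a diagonal-covariance multivariate normal): with a genuinely multivariate base distribution, the base log_prob is already a single number per sample, so the flow’s log_abs_det_jacobian would naturally be expected to sum over the event dimension as well:

```python
import torch

D = 10
# Diagonal-covariance multivariate normal: event_shape = [10], batch_shape = []
mvn_base = torch.distributions.Independent(
    torch.distributions.Normal(torch.zeros(D), torch.ones(D)), 1
)

y = mvn_base.sample()
print(mvn_base.log_prob(y).shape)  # torch.Size([]) -- one log-density per sample,
                                   # so the matching Jacobian term should also be a scalar per sample
```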


Hi @stefanwebb, thanks for your detailed explanation. It does answer my question. It seems that in the current implementation, as long as the base distribution is a univariate distribution, the calculated result will be correct.

By the way, I posted this question on the PyTorch forum but got no reply. Could you take a look at that as well?

Sure, see my post on the other forum! 🙂