Hi Xiucheng, that's a great question! It would seem reasonable for IAF to sum out the last dimension as you suggest. However, changing the code to
log_scale.sum(dim=1, keepdim=True) would cause the IAF density to be miscalculated. The reason
log_abs_det_jacobian doesn't sum over the final dimension is a bit of a hack, required because the flow is typically built on top of a univariate base distribution. Let me elaborate.
You first create the base distribution with, e.g.,
base_dist = dist.Normal(torch.zeros(10), torch.ones(10)).
Then the IAF distribution is created with
iaf = dist.InverseAutoregressiveFlow(AutoRegressiveNN(10, [40]))  # hidden sizes are illustrative
iaf_dist = dist.TransformedDistribution(base_dist, [iaf])
Let's examine how the density is calculated. Looking at the PyTorch code for TransformedDistribution, the normalizing-flow equation is computed as
x = transform.inv(y)
log_prob = log_prob - _sum_rightmost(transform.log_abs_det_jacobian(x, y), event_dim - transform.event_dim)
log_prob = log_prob + _sum_rightmost(self.base_dist.log_prob(y), event_dim - len(self.base_dist.event_shape))
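Note that nothing actually gets summed in those two lines for our example: Normal has an empty event_shape, and the transform's event_dim is evidently 0 here (otherwise log_prob wouldn't come back as a vector), so _sum_rightmost is asked to reduce zero trailing dimensions. A minimal numpy stand-in for the real helper in torch.distributions.utils, written purely for illustration, shows the behavior:

```python
import numpy as np

def sum_rightmost(value, dim):
    # Stand-in for torch.distributions.utils._sum_rightmost:
    # sum out the rightmost `dim` dimensions of `value`.
    if dim == 0:
        return value  # nothing is summed, as in the IAF case above
    return value.reshape(value.shape[:-dim] + (-1,)).sum(-1)

v = np.arange(6.0).reshape(2, 3)
print(sum_rightmost(v, 0).shape)  # (2, 3): untouched
print(sum_rightmost(v, 1))        # [ 3. 12.]: last dimension summed out
```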
In our example,
transform.log_abs_det_jacobian(x, y) ~ (log(sigma_1), ..., log(sigma_10)) and
self.base_dist.log_prob(y) ~ (log(p(x_1)), ..., log(p(x_10))), so
log_prob returns a vector where each element is
log(p(x_i)) - log(sigma_i), and summing the elements of this vector gives the correct density for IAF (see Eq (11) of the IAF paper).
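You can check this accounting numerically with a small self-contained sketch in plain numpy. A fixed affine map with a diagonal Jacobian stands in for one IAF step here, and the mu and sigma values are made up for illustration:

```python
import numpy as np

d = 10
rng = np.random.default_rng(0)
mu = rng.normal(size=d)
log_sigma = 0.1 * rng.normal(size=d)
sigma = np.exp(log_sigma)

def std_normal_logpdf(z):
    return -0.5 * z**2 - 0.5 * np.log(2.0 * np.pi)

x = rng.normal(size=d)   # sample from the N(0, I) base
y = mu + sigma * x       # affine map with diagonal Jacobian diag(sigma)

# The elementwise vector TransformedDistribution.log_prob returns:
# log p(x_i) - log sigma_i
elementwise = std_normal_logpdf(x) - log_sigma

# Direct log-density of y under N(mu, diag(sigma^2)) for comparison
direct = np.sum(std_normal_logpdf((y - mu) / sigma) - log_sigma)

assert np.isclose(elementwise.sum(), direct)
```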
If we were to change the code, we would instead have
transform.log_abs_det_jacobian(x, y) ~ sum_j(log(sigma_j)), and
log_prob would return a vector where each element is
log(p(x_i)) - sum_j(log(sigma_j)), because the summed Jacobian term broadcasts against every element. Summing this vector gives the wrong result (you can verify the difference numerically)!
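The same kind of toy setup shows the miscount concretely (plain numpy with made-up per-dimension values; the point is just the broadcasting):

```python
import numpy as np

d = 10
log_sigma = np.linspace(-0.2, 0.3, d)  # illustrative per-dimension log scales
x = np.linspace(-1.5, 1.5, d)          # a fixed "sample" from the base

def std_normal_logpdf(z):
    return -0.5 * z**2 - 0.5 * np.log(2.0 * np.pi)

# Correct accounting: one log sigma_i per element
correct = np.sum(std_normal_logpdf(x) - log_sigma)

# With log_scale.sum(dim=1, keepdim=True), the summed Jacobian term
# broadcasts against every element of the base log-prob vector:
wrong = np.sum(std_normal_logpdf(x) - log_sigma.sum())

# The Jacobian term is now counted d times instead of once:
assert np.isclose(wrong, correct - (d - 1) * log_sigma.sum())
assert not np.isclose(wrong, correct)
```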
The real problem here is that IAF should take a multivariate base distribution, such as a multivariate normal with diagonal covariance. That requires putting a Pyro wrapper around
LowRankMultivariateNormal, rewriting how IAF is used in the DMM tutorial example, and so on, but I think it's a good idea. I'm going to put together a proposal for this improvement and submit a pull request in the next few days.
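For what it's worth, here is a sketch of the bookkeeping the multivariate version would use: with an event dimension of 1, both the base log_prob and the log-det come back already reduced over the event, so log_prob is a single scalar rather than a vector, and the two conventions agree on the total (plain numpy, illustrative values):

```python
import numpy as np

d = 10
log_sigma = np.linspace(-0.2, 0.3, d)  # illustrative per-dimension log scales
x = np.linspace(-1.5, 1.5, d)          # a fixed "sample" from the base

def std_normal_logpdf(z):
    return -0.5 * z**2 - 0.5 * np.log(2.0 * np.pi)

# Current univariate-base bookkeeping: keep everything elementwise
# and let the caller sum the resulting vector at the end.
per_element = std_normal_logpdf(x) - log_sigma

# Multivariate-base bookkeeping: base log_prob and log-det are each
# already reduced over the event dimension, so log_prob is a scalar.
joint = std_normal_logpdf(x).sum() - log_sigma.sum()

assert np.isclose(per_element.sum(), joint)
```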