Transpose a distribution when input to a flow

I have a random variable X in R^{n×d}, where n indexes datapoints and d indexes the dimensions of interest.

Each dimension of X is independent, so X[:, i] ~ MVN(mu_i, sigma_i), where mu_i is a length-n mean vector and sigma_i is an n×n covariance matrix. However, I need to pass each datapoint X[j, :] through a flow of dimension d.
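To make the shapes concrete, here is a minimal sketch in plain torch.distributions (assuming unit covariances, as in the snippet below). A MultivariateNormal with a (d, n) mean has batch_shape (d,) and event_shape (n,), so a sample has one row per dimension i; transposing it gives one row per datapoint j, which is the layout a d-dimensional flow expects.

```python
import torch
from torch.distributions import MultivariateNormal

n, d = 5, 2
mu = torch.ones(d, n)                    # row i: mean of X[:, i] (length n)
cov = torch.eye(n).repeat(d, 1, 1)       # cov[i]: n x n covariance of X[:, i]

mvn = MultivariateNormal(mu, covariance_matrix=cov)
print(mvn.batch_shape, mvn.event_shape)  # torch.Size([2]) torch.Size([5])

X_by_dim = mvn.sample()                  # (d, n): row i holds X[:, i]
X = X_by_dim.T                           # (n, d): row j holds datapoint X[j, :]
print(X.shape)                           # torch.Size([5, 2])
```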

This throws an error:

import torch
import pyro.distributions as dist

n = 5; d = 2
mu = torch.ones((d, n))
sigma = torch.cat([torch.eye(n).unsqueeze(0) for i in range(d)])

mvn = dist.MultivariateNormal(mu, sigma)
flow = dist.TransformedDistribution(mvn, [dist.transforms.Planar(d)])
flow.sample()  # error: mvn's event_shape is (n,), but Planar(d) acts on vectors of length d

Is there a way to achieve what I need in Pyro?

I’ve tried to define a new transform along the lines of:

class _Transpose(dist.torch_transform.TransformModule):
    def _call(self, X):
        return X.T
    def _inverse(self, X):
        return X.T
    def log_abs_det_jacobian(self, *args):
        return torch.tensor(0.0)

But it throws various errors when I run it with SVI (mismatched array shapes, cache misses).
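One likely source of the shape mismatches: per the PyTorch docs, x.T on a tensor with more than two dimensions reverses all of them (it is equivalent to x.permute(n-1, ..., 0) and is deprecated for non-2-D tensors). Any batch or sample dimensions that Pyro adds on the left (for example from plates or a sample_shape) therefore end up on the right. A small torch check, with an arbitrary (3, 2, 5) shape standing in for 3 samples of a (d, n) matrix:

```python
import torch

x = torch.zeros(3, 2, 5)                  # e.g. 3 samples of a (d, n) = (2, 5) matrix
all_reversed = x.permute(2, 1, 0)         # what x.T does on a 3-D tensor
last_two_swapped = x.transpose(-1, -2)    # swaps only the rightmost two dims
print(all_reversed.shape)                 # torch.Size([5, 2, 3]): sample dim moved right
print(last_two_swapped.shape)             # torch.Size([3, 5, 2]): sample dim stays left
```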

Hi @aditya, here’s a working version of your snippet. See the “Distribution shapes” section of our tensor shape tutorial for background.

import torch
import pyro.distributions as dist

n = 5; d = 2
mu = torch.ones((n, d))  # batch dimensions to the left: (n, d), not (d, n)
sigma = torch.ones((n, d))  # save space: diagonal instead of full dxd identity matrix

mvn = dist.Normal(mu, sigma)  # independent Normals
mvn = mvn.to_event(1)  # indicate that samples are vectors: mark rightmost dim as "event dimension"

flow = dist.TransformedDistribution(mvn, [dist.transforms.Planar(d)])
flow.sample()
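In torch terms, .to_event(1) behaves like wrapping the distribution in torch.distributions.Independent(base, 1): the rightmost batch dimension is reinterpreted as an event dimension. The sketch below uses plain torch (an assumption, to keep it free of Pyro) to show the resulting shapes, including one log-density per datapoint:

```python
import torch
from torch.distributions import Independent, Normal

n, d = 5, 2
base = Normal(torch.ones(n, d), torch.ones(n, d))  # n*d independent scalars
print(base.batch_shape, base.event_shape)          # torch.Size([5, 2]) torch.Size([])

vec = Independent(base, 1)                # rightmost dim (size d) becomes the event
print(vec.batch_shape, vec.event_shape)   # torch.Size([5]) torch.Size([2])

lp = vec.log_prob(vec.sample())           # sums over d: one log-density per datapoint
print(lp.shape)                           # torch.Size([5])
```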

Thanks, but X is expected to be independent across dimensions, not across datapoints. Although I showed sigma as an identity matrix above, in my use case it is a full covariance matrix.
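With a full covariance per dimension, the (d, n) mean pairs with a (d, n, n) covariance, one n×n matrix per dimension. A minimal sketch, assuming the covariances are supplied as lower-triangular factors through scale_tril (the random values here are purely illustrative). This skips the internal Cholesky step and guarantees a valid covariance L @ L^T:

```python
import torch
from torch.distributions import MultivariateNormal

torch.manual_seed(0)
n, d = 5, 2
mu = torch.zeros(d, n)

# Illustrative factors: strictly-lower part of a random matrix plus a positive diagonal,
# giving one valid Cholesky factor per dimension i.
A = torch.randn(d, n, n)
L = torch.tril(A, diagonal=-1) + torch.diag_embed(A.diagonal(dim1=-2, dim2=-1).abs() + 0.1)

mvn = MultivariateNormal(mu, scale_tril=L)
print(mvn.batch_shape, mvn.event_shape)   # torch.Size([2]) torch.Size([5])
cov = mvn.covariance_matrix               # (d, n, n) = L @ L^T, generally not diagonal
```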

I see. In that case you could use something like your _Transpose, but with .transpose(-1, -2) rather than .T so that batch dimensions are not affected. Does this work in your Pyro model?

import torch
import pyro.distributions as dist

class _Transpose(dist.transforms.Transform):
    def __init__(self):
        super().__init__(cache_size=1)  # cache so inv() returns the exact tensor Planar produced; Planar's inverse relies on its own cache
    def _call(self, X):
        return X.transpose(-1, -2)  # only transpose the rightmost two dimensions
    def _inverse(self, X):
        return X.transpose(-1, -2)
    def log_abs_det_jacobian(self, *args):
        return torch.tensor(0.0)

n = 5; d = 2
mu = torch.ones((d, n))
sigma = torch.cat([torch.eye(n).unsqueeze(0) for i in range(d)])

mvn = dist.MultivariateNormal(mu, sigma).to_event(1)  # note to_event for Pyro's benefit here
flow = dist.TransformedDistribution(mvn, [_Transpose(), dist.transforms.Planar(d), _Transpose()])
assert flow.sample().shape == (d, n)
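For comparison, here is a minimal sketch in plain torch.distributions rather than Pyro (an assumption), with AffineTransform(..., event_dim=1) standing in for Planar(d), since torch has no Planar. TransposeLast2 is a hypothetical name. Unlike the version above, it declares two event dims in its domain and codomain, reports the swapped shape via forward_shape/inverse_shape, and returns a zero log-Jacobian with the batch shape instead of a bare scalar. It assumes a PyTorch recent enough for AffineTransform to accept event_dim.

```python
import torch
from torch.distributions import (AffineTransform, Independent, MultivariateNormal,
                                 TransformedDistribution, constraints)
from torch.distributions.transforms import Transform


class TransposeLast2(Transform):
    """Swaps the two rightmost dims. A permutation, so log|det J| = 0."""
    # Declare two event dims so TransformedDistribution treats (d, n) as one event.
    domain = constraints.independent(constraints.real, 2)
    codomain = constraints.independent(constraints.real, 2)
    bijective = True

    def _call(self, x):
        return x.transpose(-1, -2)  # batch/sample dims on the left are untouched

    def _inverse(self, y):
        return y.transpose(-1, -2)

    def log_abs_det_jacobian(self, x, y):
        # One zero per batch element: shape x.shape[:-2], not a bare scalar.
        return x.new_zeros(x.shape[:-2])

    def forward_shape(self, shape):
        return torch.Size((*shape[:-2], shape[-1], shape[-2]))

    def inverse_shape(self, shape):
        return torch.Size((*shape[:-2], shape[-1], shape[-2]))


n, d = 5, 2
mu = torch.zeros(d, n)
cov = torch.eye(n).repeat(d, 1, 1)                     # (d, n, n)
base = Independent(MultivariateNormal(mu, cov), 1)     # event_shape (d, n)

# Hypothetical stand-in for Planar(d): elementwise affine map on each length-d row.
scale = torch.tensor([2.0, 0.5])
row_flow = AffineTransform(torch.zeros(d), scale, event_dim=1)

flow = TransformedDistribution(base, [TransposeLast2(), row_flow, TransposeLast2()])
y = flow.sample((3,))
print(y.shape)                              # torch.Size([3, 2, 5])
print(flow.batch_shape, flow.event_shape)   # torch.Size([]) torch.Size([2, 5])
print(flow.log_prob(y).shape)               # torch.Size([3])
```

Because AffineTransform has a closed-form inverse, this sketch needs no caching; with Pyro's Planar, which relies on its cache for inverses, the cache_size=1 in the snippet above is still needed.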