I have a random variable X in R^{n×d}, where n indexes datapoints and d indexes the dimensions of interest.
The dimensions of X are independent, so X[:, i] ~ MVN(mu_i, sigma_i) for each i. However, I need to pass each datapoint X[j, :] through a flow of dimension d.
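A minimal sketch of this setup with plain torch.distributions, assuming hypothetical random parameters: mu of shape (d, n) and one full n×n covariance per dimension. The d independent MVNs sit in the batch dimension, and Independent(..., 1) reinterprets that batch dimension as part of the event. The result is "independent across dimensions, correlated across datapoints", and the joint log density equals the sum of the per-column MVN log densities.

```python
import torch
from torch.distributions import Independent, MultivariateNormal

torch.manual_seed(0)
n, d = 5, 2

# Hypothetical parameters: a length-n mean and a full n x n covariance per dimension.
mu = torch.randn(d, n)
A = torch.randn(d, n, n)
sigma = A @ A.transpose(-1, -2) + 0.1 * torch.eye(n)  # positive definite
sigma = 0.5 * (sigma + sigma.transpose(-1, -2))       # exactly symmetric

# Batch of d MVNs, each over the n datapoints of one column of X.
per_column = MultivariateNormal(mu, covariance_matrix=sigma)
# Treat the d columns jointly: the batch dimension becomes part of the event.
joint = Independent(per_column, 1)

X = joint.sample().transpose(-1, -2)  # shape (n, d): one row per datapoint

# Independence across dimensions: the joint density factorises over columns.
lp_joint = joint.log_prob(X.transpose(-1, -2))
lp_sum = sum(
    MultivariateNormal(mu[i], covariance_matrix=sigma[i]).log_prob(X[:, i])
    for i in range(d)
)
print(X.shape, torch.allclose(lp_joint, lp_sum, atol=1e-5))
```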
This throws an error:
```python
import torch
import pyro.distributions as dist

n = 5; d = 2
mu = torch.ones((d, n))
sigma = torch.cat([torch.eye(n).unsqueeze(0) for i in range(d)])
mvn = dist.MultivariateNormal(mu, sigma)
flow = dist.TransformedDistribution(mvn, [dist.transforms.Planar(d)])
flow.sample()  # error
```
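The shapes suggest why this fails. A quick check with plain torch.distributions (no flow involved) shows that this MultivariateNormal has batch shape (d,) and event shape (n,). Each sample therefore has n = 5 entries in its rightmost dimension, while Planar(d) transforms vectors of length d = 2 along that dimension.

```python
import torch
from torch.distributions import MultivariateNormal

n, d = 5, 2
mu = torch.ones((d, n))
sigma = torch.eye(n).repeat(d, 1, 1)  # same (d, n, n) stack as the torch.cat(...) above

mvn = MultivariateNormal(mu, sigma)
print(mvn.batch_shape, mvn.event_shape)  # torch.Size([2]) torch.Size([5])

x = mvn.sample()
print(x.shape)  # torch.Size([2, 5])
# A vector transform acts on the rightmost dimension, which has size n, not d.
assert x.shape[-1] == n and n != d
```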
Something like this should work: put the batch dimension on the left and the event dimension (the one Planar acts on) on the right.

```python
import torch
import pyro.distributions as dist

n = 5; d = 2
mu = torch.ones((n, d))  # batch dimensions to the left: (n, d), not (d, n)
sigma = torch.ones((n, d))  # save space: diagonal instead of full d x d identity matrix
mvn = dist.Normal(mu, sigma)  # independent Normals
mvn = mvn.to_event(1)  # indicate that samples are vectors: mark rightmost dim as "event dimension"
flow = dist.TransformedDistribution(mvn, [dist.transforms.Planar(d)])
flow.sample()
```
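For comparison, the same shape check on this version, written with torch's Independent (the torch counterpart of what Pyro's .to_event(1) builds). It shows n = 5 independent events of length d = 2, which is what Planar(d) expects, and log_prob accordingly returns one value per datapoint.

```python
import torch
from torch.distributions import Independent, Normal

n, d = 5, 2
mu = torch.ones((n, d))
sigma = torch.ones((n, d))  # Normal takes standard deviations, not variances

base = Independent(Normal(mu, sigma), 1)  # rightmost dim becomes the event dim
print(base.batch_shape, base.event_shape)  # torch.Size([5]) torch.Size([2])

x = base.sample()
print(x.shape, base.log_prob(x).shape)  # torch.Size([5, 2]) torch.Size([5])
```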
Thanks. X is expected to be independent across dimensions, but not across datapoints. Although I've shown sigma as a diagonal (identity) matrix above, in my use case it is a full covariance matrix.
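I see. In that case you could use something like your _Transpose, but with .transpose(-1, -2) rather than .T so that batch dimensions are not affected: on a tensor with more than two dimensions, .T reverses all of them, whereas .transpose(-1, -2) swaps only the rightmost two. Does this work in your Pyro model?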
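The difference between the two calls can be checked on a hypothetical batch of three d×n matrices. PyTorch documents x.T as reversing all dimensions (and deprecates that use for tensors that are not 2-D), so the sketch reproduces it with permute(2, 1, 0) instead of calling .T directly.

```python
import torch

batch, d, n = 3, 2, 5
x = torch.randn(batch, d, n)  # a batch of three d x n matrices

# .transpose(-1, -2) swaps only the two rightmost dims; the batch dim stays put.
y = x.transpose(-1, -2)
print(y.shape)  # torch.Size([3, 5, 2])

# What .T does on a 3-D tensor: reverse every dim, moving the batch dim too.
z = x.permute(2, 1, 0)
print(z.shape)  # torch.Size([5, 2, 3])

assert torch.equal(y[1], x[1].t())  # each matrix in the batch is transposed on its own
```

Recent PyTorch versions also provide x.mT, which transposes only the last two dimensions like .transpose(-1, -2).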
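```python
import torch
import pyro.distributions as dist
from torch.distributions import constraints

class _Transpose(dist.transforms.Transform):
    # acts on matrices (event_dim=2); recent PyTorch reads domain/codomain when composing transforms
    domain = constraints.independent(constraints.real, 2)
    codomain = constraints.independent(constraints.real, 2)
    bijective = True
    def __init__(self):
        super().__init__(cache_size=1)  # cache forward computation so Planar works
    def _call(self, X):
        return X.transpose(-1, -2)  # only transpose the rightmost two dimensions
    def _inverse(self, X):
        return X.transpose(-1, -2)
    def log_abs_det_jacobian(self, x, y):
        return x.new_zeros(x.shape[:-2])  # a permutation has |det J| = 1
    def forward_shape(self, shape):
        return torch.Size([*shape[:-2], shape[-1], shape[-2]])
    inverse_shape = forward_shape  # transposing is its own inverse

n = 5; d = 2
mu = torch.ones((d, n))
sigma = torch.cat([torch.eye(n).unsqueeze(0) for i in range(d)])  # one n x n covariance per dimension
mvn = dist.MultivariateNormal(mu, sigma).to_event(1)  # note to_event for Pyro's benefit here
flow = dist.TransformedDistribution(mvn, [_Transpose(), dist.transforms.Planar(d), _Transpose()])
assert flow.sample().shape == (d, n)
```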