Hi all,
I was wondering if there is a way to “stack” two independent univariate distributions p_1(x), p_2(y) to obtain a Distribution object for their joint distribution p(x, y) = p_1(x) * p_2(y). This would let me conveniently use normalizing-flow transforms such as SplineCoupling when the individual dimensions have a heterogeneous set of base distributions.
There’s no preexisting functionality for this, but it should be straightforward to implement a new distribution that does what you want, using, e.g., the source code of `MaskedMixture` as a reference. Here’s a snippet to get you started:
```python
import torch
import pyro.distributions
from pyro.distributions.util import broadcast_shape

class StackDistribution(pyro.distributions.torch_distribution.TorchDistribution):
    def __init__(self, *base_dists, validate_args=True):
        # All components must share one event_shape so they can be stacked.
        assert len(set(base_dist.event_shape for base_dist in base_dists)) == 1, \
            "All base_dists should have the same event_shape"
        batch_shape = broadcast_shape(*[base_dist.batch_shape for base_dist in base_dists])
        # Stacking adds a new rightmost event dimension, one entry per component.
        event_shape = base_dists[0].event_shape + (len(base_dists),)
        self.base_dists = tuple(base_dist.expand(batch_shape) for base_dist in base_dists)
        super().__init__(batch_shape, event_shape, validate_args)

    ...

    def log_prob(self, x):
        # Independent components: the joint log-density is the sum of the marginals.
        return sum(base_dist.log_prob(x[..., i]) for i, base_dist in enumerate(self.base_dists))

    def rsample(self, sample_shape=torch.Size()):
        # Draw from each component and stack along the new last dimension.
        return torch.stack([base_dist.rsample(sample_shape) for base_dist in self.base_dists], dim=-1)
```
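If you want to try the idea without Pyro in the loop, here is a minimal sketch of the same construction on plain `torch.distributions.Distribution` (which Pyro's `TorchDistribution` extends). It is used as the base of a tiny flow. Several assumptions are not in the snippet above:

* `torch.broadcast_shapes` stands in for Pyro's `broadcast_shape`.
* Every component is assumed to implement `rsample()`.
* A single elementwise `AffineTransform` stands in for a learned layer such as `SplineCoupling`.
* The Normal/Gamma pairing and all parameter values are purely illustrative.

Taking logs of p(x, y) = p_1(x) * p_2(y) gives log p(x, y) = log p_1(x) + log p_2(y), which is why `log_prob` is a sum.

```python
import torch
from torch.distributions import (AffineTransform, Distribution, Gamma, Normal,
                                 TransformedDistribution)


class StackDistribution(Distribution):
    """Joint of independent components, stacked along a new last event dimension."""

    arg_constraints = {}  # no tensor parameters of its own to validate
    has_rsample = True    # assumption: every component implements rsample()

    def __init__(self, *base_dists, validate_args=None):
        if len({d.event_shape for d in base_dists}) != 1:
            raise ValueError("All base_dists should have the same event_shape")
        batch_shape = torch.broadcast_shapes(*(d.batch_shape for d in base_dists))
        event_shape = base_dists[0].event_shape + torch.Size([len(base_dists)])
        self.base_dists = tuple(d.expand(batch_shape) for d in base_dists)
        super().__init__(batch_shape, event_shape, validate_args=validate_args)

    def log_prob(self, x):
        # log p(x, y) = log p_1(x) + log p_2(y) for independent components
        return sum(d.log_prob(x[..., i]) for i, d in enumerate(self.base_dists))

    def rsample(self, sample_shape=torch.Size()):
        return torch.stack([d.rsample(sample_shape) for d in self.base_dists], dim=-1)


torch.manual_seed(0)
# Heterogeneous base: a real-valued Normal and a positive Gamma component.
base = StackDistribution(Normal(0.0, 1.0), Gamma(2.0, 1.0))
xy = base.sample((5,))   # shape (5, 2); column 1 is always positive
lp = base.log_prob(xy)   # shape (5,)

# Any transform acting on a length-2 event can sit on top; an elementwise
# affine map stands in here for a learned coupling layer.
loc, scale = torch.tensor([1.0, -1.0]), torch.tensor([2.0, 3.0])
flow = TransformedDistribution(base, [AffineTransform(loc, scale, event_dim=1)])
z = flow.sample((5,))
flow_lp = flow.log_prob(z)  # base log-density minus log|det J| = log 6
```

Because each component keeps its own `log_prob`, the stacked density stays exact. The trade-off is that a component without `rsample()` (for example, a discrete one) would need `sample()` and `has_rsample = False` instead.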
Thank you, that worked out nicely.