I was wondering if there is a way to “stack” two independent univariate distributions p_1(x), p_2(y) to obtain a `Distribution` object for their joint distribution p(x, y) = p_1(x) * p_2(y). This would let me conveniently use normalizing-flow transforms such as `SplineCoupling` with a heterogeneous set of base distributions for the individual dimensions.
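Written in log space, the factorization asked about becomes a sum over components. This sum is what a stacked distribution's `log_prob` has to return. For n independent components z_i:

$$
\log p(x, y) = \log p_1(x) + \log p_2(y),
\qquad
\log p(z_1, \dots, z_n) = \sum_{i=1}^{n} \log p_i(z_i).
$$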
There’s no preexisting functionality for this, but it should be straightforward to implement a new distribution that does what you want, using e.g. the source code of `MaskedMixture` as a reference. Here’s a snippet to get you started:
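```python
import torch
from pyro.distributions.torch_distribution import TorchDistribution
from pyro.distributions.util import broadcast_shape


class StackDistribution(TorchDistribution):
    arg_constraints = {}  # no parameters of its own; avoids a validation warning
    has_rsample = True  # rsample() is implemented below

    def __init__(self, *base_dists, validate_args=True):
        assert len(set(base_dist.event_shape for base_dist in base_dists)) == 1, \
            "All base_dists should have the same event_shape"
        batch_shape = torch.Size(broadcast_shape(*[base_dist.batch_shape for base_dist in base_dists]))
        # The stacking dimension becomes the rightmost event dimension.
        event_shape = torch.Size(base_dists[0].event_shape + (len(base_dists),))
        self.base_dists = tuple(base_dist.expand(batch_shape) for base_dist in base_dists)
        super().__init__(batch_shape, event_shape, validate_args)

    ...

    def log_prob(self, x):
        # Independent components: the joint log-density is the sum of the marginal log-densities.
        return sum(base_dist.log_prob(x[..., i]) for i, base_dist in enumerate(self.base_dists))

    def rsample(self, sample_shape=torch.Size()):
        # Draw each component separately and stack the draws along the last dimension.
        return torch.stack([base_dist.rsample(sample_shape) for base_dist in self.base_dists], dim=-1)
```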
Thank you, that worked out nicely.