Joint distribution object

Hi all,
I was wondering if there is a way to “stack” two independent univariate distributions p_1(x), p_2(y) to obtain a Distribution object for their joint distribution p(x, y) = p_1(x) * p_2(y). This would let me conveniently use normalizing-flow transforms such as SplineCoupling with a heterogeneous set of base distributions for the individual dimensions.
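Distributions report densities in log space through `log_prob`, so the product in the question becomes a sum. A joint wrapper therefore only has to add up the component log-densities. For $n$ independent components stacked into $z = (z_1, \dots, z_n)$:

$$
\log p(z) = \sum_{i=1}^{n} \log p_i(z_i), \qquad \text{e.g. } \log p(x, y) = \log p_1(x) + \log p_2(y).
$$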

There’s no preexisting functionality for this, but it should be straightforward to implement a new distribution that does what you want, using e.g. the source code of MaskedMixture as a reference. Here’s a snippet to get you started:

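import torch
from pyro.distributions.torch_distribution import TorchDistribution
from pyro.distributions.util import broadcast_shape

class StackDistribution(TorchDistribution):
    arg_constraints = {}  # all parameters live in the component distributions

    def __init__(self, *base_dists, validate_args=None):
        assert len(set(base_dist.event_shape for base_dist in base_dists)) == 1, \
            "All base_dists should have the same event_shape"
        # Broadcast the components' batch shapes against each other
        batch_shape = torch.Size(broadcast_shape(*[base_dist.batch_shape for base_dist in base_dists]))
        # Stacking adds one rightmost event dim of size len(base_dists)
        event_shape = torch.Size(base_dists[0].event_shape + (len(base_dists),))
        self.base_dists = tuple(base_dist.expand(batch_shape) for base_dist in base_dists)
        super().__init__(batch_shape, event_shape, validate_args=validate_args)
        ...

    @property
    def has_rsample(self):
        # Reparameterized sampling is only available if every component supports it
        return all(base_dist.has_rsample for base_dist in self.base_dists)

    def log_prob(self, x):
        # Independent components: the joint log-density is the sum of the marginal ones
        return sum(base_dist.log_prob(x[..., i]) for i, base_dist in enumerate(self.base_dists))

    def rsample(self, sample_shape=torch.Size()):
        # Draw each component separately and stack them along a new last dim
        return torch.stack([base_dist.rsample(sample_shape) for base_dist in self.base_dists], dim=-1)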

Thank you, that worked out nicely.
