How to make custom distributions funsor compatible?

Hi everyone. I have several custom-written distributions which are subclasses of either TorchDistribution or TransformedDistribution. How do I make them funsor compatible? Thanks.

Hi @ordabayev, for TorchDistribution subclasses you can use funsor.distribution.make_dist, which takes a TorchDistribution subclass and (optionally) the names of its parameters and returns a new funsor.distribution.Distribution subclass:

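```python
import funsor
from pyro.distributions import TorchDistribution
class CustomTorchDistribution(TorchDistribution):
    def __init__(self, param1, param2, validate_args=None):
        ...  # store param1/param2 as attributes, then call super().__init__(...)

CustomFunsorDistribution = funsor.distribution.make_dist(CustomTorchDistribution, ("param1", "param2"))
```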

Note also that for make_dist to generate a correct wrapper, you'll need to make sure to implement CustomTorchDistribution.support and CustomTorchDistribution.arg_constraints. You may also need to implement a pattern for automatic conversion with funsor.to_funsor, although this is probably something we should automate and include in make_dist:

```python
# Register the converter so funsor.to_funsor(CustomTorchDistribution(...)) dispatches here.
@funsor.to_funsor.register(CustomTorchDistribution)
def mydist_to_funsor(backend_dist, output=None, dim_to_name=None):
    funsor_param1 = funsor.to_funsor(backend_dist.param1, output=CustomFunsorDistribution._infer_param_domain("param1", backend_dist.param1.shape), dim_to_name=dim_to_name)
    funsor_param2 = funsor.to_funsor(backend_dist.param2, output=CustomFunsorDistribution._infer_param_domain("param2", backend_dist.param2.shape), dim_to_name=dim_to_name)
    return CustomFunsorDistribution(funsor_param1, funsor_param2)
```

Please feel free to ask more questions or open bug reports or feature requests if you have any trouble.

We’re still working on full support for TransformedDistributions in Funsor; see this PR for details and this issue for a tracker of progress toward broader distribution API coverage. New use cases are very helpful in constraining our design, but if we can’t add support for yours soon enough it should still be possible to work around this missing support if all you need to do is evaluate log-probabilities and sample.


Hi @eb8680_2, I’ve tried out your PR. So far it worked for one of my distributions. Just as a note, apart from .support and .arg_constraints, I had to make sure that

  1. distribution parameters can be accessed as properties .param1 and .param2
  2. the __init__ method takes a validate_args argument
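The four requirements above (.support, .arg_constraints, attribute access to parameters, and a validate_args argument) can be checked before calling make_dist. The helper below, check_funsor_ready, is hypothetical and not part of funsor or pyro; it only tests these four conditions using inspect and torch, looking support and arg_constraints up on the class because that is where torch's built-in distributions define them. Passing it does not guarantee that make_dist will succeed.

```python
import inspect

import torch
from torch.distributions import constraints


def check_funsor_ready(dist_class, param_names, example_params):
    """Hypothetical pre-flight check for a distribution class before make_dist.

    Returns a list of human-readable problems; an empty list means the four
    requirements discussed in this thread appear to be met.
    """
    problems = []
    # 1. support: a Constraint defined on the class.
    if not isinstance(getattr(dist_class, "support", None), constraints.Constraint):
        problems.append("support is not defined as a Constraint on the class")
    # 2. arg_constraints: a class-level dict with one entry per parameter name.
    arg_constraints = getattr(dist_class, "arg_constraints", None)
    if not isinstance(arg_constraints, dict):
        problems.append("arg_constraints is not defined as a dict on the class")
        arg_constraints = {}
    for name in param_names:
        if name not in arg_constraints:
            problems.append(f"arg_constraints has no entry for {name!r}")
    # 3. __init__ must accept validate_args.
    if "validate_args" not in inspect.signature(dist_class.__init__).parameters:
        problems.append("__init__ does not accept validate_args")
        return problems  # skip instantiation, which would pass validate_args
    # 4. Parameters must be readable as tensor attributes, e.g. instance.param1.
    instance = dist_class(**example_params, validate_args=False)
    for name in param_names:
        if not isinstance(getattr(instance, name, None), torch.Tensor):
            problems.append(f"parameter {name!r} is not accessible as a tensor attribute")
    return problems
```

For example, check_funsor_ready(torch.distributions.Normal, ("loc", "scale"), {"loc": 0.0, "scale": 1.0}) returns an empty list, since Normal satisfies all four conditions.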