Hi everyone. I have several custom-written distributions that are subclasses of either TorchDistribution or TransformedDistribution. How do I make them Funsor-compatible? Thanks.
Hi @ordabayev, for TorchDistribution subclasses you can use funsor.distribution.make_dist, which takes a TorchDistribution subclass and (optionally) the names of its parameters and returns a new funsor.distribution.Distribution subclass:
import funsor
from pyro.distributions import TorchDistribution

class CustomTorchDistribution(TorchDistribution):
    def __init__(self, param1, param2):
        ...

CustomFunsorDistribution = funsor.distribution.make_dist(CustomTorchDistribution, ("param1", "param2"))
Note also that for make_dist to generate a correct wrapper, you’ll need to implement CustomTorchDistribution.support and CustomTorchDistribution.arg_constraints. You may also need to register a pattern for automatic conversion with funsor.to_funsor, although this is probably something we should automate and include in make_dist:
@funsor.to_funsor.register(CustomTorchDistribution)
def mydist_to_funsor(backend_dist, output=None, dim_to_name=None):
    # Convert each parameter tensor to a funsor, inferring its domain from the tensor shape.
    funsor_param1 = funsor.to_funsor(backend_dist.param1, output=CustomFunsorDistribution._infer_param_domain("param1", backend_dist.param1.shape), dim_to_name=dim_to_name)
    funsor_param2 = funsor.to_funsor(backend_dist.param2, output=CustomFunsorDistribution._infer_param_domain("param2", backend_dist.param2.shape), dim_to_name=dim_to_name)
    return CustomFunsorDistribution(funsor_param1, funsor_param2)
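To make the support and arg_constraints requirement concrete, here is a minimal sketch of a custom distribution that declares both. The Logistic class and its loc/scale parameters are hypothetical, chosen only for illustration. It subclasses torch.distributions.Distribution directly so that it depends on PyTorch alone; in Pyro you would presumably derive from TorchDistribution instead, and I have not run this class through make_dist.

```python
import torch
import torch.nn.functional as F
from torch.distributions import Distribution, constraints
from torch.distributions.utils import broadcast_all


class Logistic(Distribution):
    """Hypothetical example distribution (not part of Funsor or Pyro)."""

    # One constraint per constructor parameter, keyed by parameter name.
    arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
    # Constraint describing the set of values the distribution can produce.
    support = constraints.real
    has_rsample = True

    def __init__(self, loc, scale, validate_args=None):
        # Store parameters as attributes whose names match arg_constraints.
        self.loc, self.scale = broadcast_all(loc, scale)
        super().__init__(batch_shape=self.loc.shape, validate_args=validate_args)

    def rsample(self, sample_shape=torch.Size()):
        shape = self._extended_shape(sample_shape)
        eps = torch.finfo(self.loc.dtype).eps
        u = torch.rand(shape, dtype=self.loc.dtype, device=self.loc.device)
        u = u.clamp(min=eps, max=1 - eps)
        # Inverse-CDF sampling: the logistic quantile function is the logit.
        return self.loc + self.scale * (u.log() - (-u).log1p())

    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        z = (value - self.loc) / self.scale
        # log f(x) = -z - log(scale) - 2 * log(1 + exp(-z))
        return -z - self.scale.log() - 2 * F.softplus(-z)
```

With validate_args=True, PyTorch checks the constructor arguments against arg_constraints and the values passed to log_prob against support, so a mismatch between the declared constraints and the actual parameter names tends to surface early.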
Please feel free to ask more questions or open bug reports or feature requests if you have any trouble.
We’re still working on full support for TransformedDistributions in Funsor; see this PR for details and this issue for a tracker of progress toward broader distribution API coverage. New use cases are very helpful in constraining our design, but if we can’t add support for yours soon enough, it should still be possible to work around this missing support if all you need to do is evaluate log-probabilities and sample.
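The exact workaround is not spelled out here. One possible reading, and this is my assumption rather than an official recipe, is to keep the TransformedDistribution on the PyTorch side and call its own log_prob and rsample methods, without going through a Funsor wrapper. A minimal sketch using a log-normal built from a Normal base and an ExpTransform, both picked arbitrarily for illustration:

```python
import torch
from torch.distributions import Normal, TransformedDistribution
from torch.distributions.transforms import ExpTransform

# Base distribution and transform are illustrative choices only.
base = Normal(torch.tensor(0.0), torch.tensor(1.0))
log_normal = TransformedDistribution(base, [ExpTransform()])

# Draw reparameterized samples and score them with plain PyTorch calls.
x = log_normal.rsample((5,))
log_p = log_normal.log_prob(x)
```

The resulting tensors could then be handed to the rest of the model as ordinary data; how best to do that depends on the model.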