Hi everyone. I have several custom-written distributions that are subclasses of either `TorchDistribution` or `TransformedDistribution`. How do I make them Funsor-compatible? Thanks.
Hi @ordabayev, for `TorchDistribution` subclasses you can use `funsor.distribution.make_dist`, which takes a `TorchDistribution` subclass and (optionally) the names of its parameters and returns a new Funsor distribution class:
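```python
import funsor
from pyro.distributions import TorchDistribution

funsor.set_backend("torch")


class CustomTorchDistribution(TorchDistribution):
    # make_dist also needs arg_constraints defined on this class (see below).
    def __init__(self, param1, param2):
        ...


CustomFunsorDistribution = funsor.distribution.make_dist(CustomTorchDistribution, ("param1", "param2"))
```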
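When the parameter names are omitted, the wrapper has to work them out on its own. Below is a minimal pure-Python sketch of one way to do that, assuming (as recent funsor versions appear to do) that the names come from the `__init__` signature filtered by `arg_constraints`. The helper `infer_param_names` and the stand-in class are hypothetical; no funsor or torch code is used:

```python
import inspect


def infer_param_names(dist_class):
    """Hypothetical helper: __init__ args that also appear in arg_constraints."""
    # getfullargspec(...).args lists "self" first, so skip it.
    init_args = inspect.getfullargspec(dist_class.__init__).args[1:]
    constrained = getattr(dist_class, "arg_constraints", {})
    return tuple(name for name in init_args if name in constrained)


class CustomTorchDistribution:
    # Stand-in for a TorchDistribution subclass; real code would map each
    # parameter to a torch.distributions.constraints object, not a string.
    arg_constraints = {"param1": "real", "param2": "positive"}

    def __init__(self, param1, param2, validate_args=None):
        self.param1 = param1
        self.param2 = param2


print(infer_param_names(CustomTorchDistribution))  # ('param1', 'param2')
```

Note that `validate_args` is dropped because it has no constraint, and a class with no `arg_constraints` at all would yield an empty tuple under this scheme.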
Note also that for `make_dist` to generate a correct wrapper, you'll need to implement `CustomTorchDistribution.arg_constraints`. You may also need to register a pattern for automatic conversion with `funsor.to_funsor`, although this is probably something we should automate and include in …
```python
@funsor.to_funsor.register(CustomTorchDistribution)
def mydist_to_funsor(backend_dist, output=None, dim_to_name=None):
    # Convert each raw parameter tensor to a funsor, using the parameter's
    # inferred domain as its output and dim_to_name to name its batch dims.
    funsor_param1 = funsor.to_funsor(
        backend_dist.param1,
        output=CustomFunsorDistribution._infer_param_domain("param1", backend_dist.param1.shape),
        dim_to_name=dim_to_name,
    )
    funsor_param2 = funsor.to_funsor(
        backend_dist.param2,
        output=CustomFunsorDistribution._infer_param_domain("param2", backend_dist.param2.shape),
        dim_to_name=dim_to_name,
    )
    return CustomFunsorDistribution(funsor_param1, funsor_param2)
```
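The `output=` and `dim_to_name=` arguments in that registration split each raw parameter tensor into batch dimensions, which become named Funsor inputs, and a rightmost event part, which becomes the output domain. The sketch below illustrates that bookkeeping on shapes alone. `split_param_shape` is hypothetical, and the convention it assumes (that `dim_to_name` keys are negative batch dims counted from the right, Pyro-style, with size-1 dims dropped) is modeled on Pyro's plate dims rather than taken from funsor's actual code:

```python
from collections import OrderedDict


def split_param_shape(raw_shape, event_shape, dim_to_name):
    """Hypothetical: split raw_shape into named batch sizes and an event shape."""
    event_ndims = len(event_shape)
    # The rightmost dims must match the event (output) shape.
    if event_ndims and tuple(raw_shape[-event_ndims:]) != tuple(event_shape):
        raise ValueError("rightmost dims must match the event shape")
    batch_shape = tuple(raw_shape[: len(raw_shape) - event_ndims])
    inputs = OrderedDict()
    for dim, name in dim_to_name.items():
        # dim is negative and indexes batch_shape from the right;
        # size-1 (broadcast) dims are assumed not to become named inputs.
        if -dim <= len(batch_shape) and batch_shape[dim] != 1:
            inputs[name] = batch_shape[dim]
    return inputs, tuple(event_shape)


# A parameter of raw shape (3, 1, 2) with a length-2 event part:
inputs, event_shape = split_param_shape((3, 1, 2), (2,), OrderedDict([(-1, "j"), (-2, "i")]))
# inputs holds a single named batch dim "i" of size 3; event_shape is (2,)
```

In funsor itself the event part is worked out by `_infer_param_domain` from the parameter's shape, which is one more reason `arg_constraints` needs to be correct.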
Please feel free to ask more questions or open bug reports or feature requests if you have any trouble.
We're still working on full support for `TransformedDistribution`s in Funsor; see this PR for details and this issue for a tracker of progress toward broader distribution API coverage. New use cases are very helpful in constraining our design. If we can't add support for yours soon enough, you should still be able to work around the missing support, provided all you need to do is evaluate log-probabilities and draw samples.