How to make my Reparam return the square root of variable?

I have implemented a reparameterization of the half-student-t distribution following “Mean field variational Bayes for elaborate distributions. Bayesian Analysis, 6(4), 2011”:

As you can see, this reparameterization returns x^2, so the reparam needs to return the square root. How can I do so?

It is implemented as followings:

# TODO: NOT DONE! this returns x^2, not x
class HalfStudentTReparam(Reparam):
    def apply(self, msg):
        name = msg["name"]
        fn = msg["fn"]
        value = msg["value"]
        is_observed = msg["is_observed"]

        fn, event_dim = self._unwrap(fn)
        assert isinstance(fn, HalfStudentT)

        rate = 1 / (fn.scale*fn.scale)
        a = pyro.sample(f'{name}_invgamma', 
            self._wrap(dist.InverseGamma(0.5, rate), event_dim))
        new_fn = self._wrap(dist.InverseGamma(fn.df*0.5, fn.df/a), event_dim)
        return {'fn': new_fn, 'value': value, 'is_observed': is_observed}

Since the half-student-t distribution is not part of pyro and pytorch, I offer the code as well. Not sure if it is correct though…

class HalfStudentT(TransformedDistribution, TorchDistributionMixin):
    arg_constraints = {
        'df': constraints.positive,
        'scale': constraints.positive,
    support = constraints.nonnegative
    has_rsample = True

    def __init__(self, df, scale=1.0, validate_args=None):
        base_dist = StudentT(df, scale=scale, validate_args=False)
        super().__init__(base_dist, AbsTransform(), validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(HalfStudentT, _instance)
        return super().expand(batch_shape, _instance=new)

    def scale(self):
        return self.base_dist.scale

    def df(self):
        return self.base_dist.df

    def log_prob(self, value):
        if self._validate_args:
        log_prob = self.base_dist.log_prob(value) + math.log(2)
        log_prob = torch.where(value >= 0, log_prob, -inf)
        return log_prob

Hi @mochar, I think you can define a SquareRoot transform with positive domain and codomain. Then you can do

new_fn = self._wrap(dist.TransformedDistribution(dist.InverseGamma(...), SquareRoot()), ...)

Thank you so much!

Does this implementation of the square root transform look good?

class SquareRootTransform(Transform):
    def __init__(self):

    def __call__(self, x):
        return torch.sqrt(x)

    def _inverse(self, y):
        return y ** 2

    def log_abs_det_jacobian(self, x, y):
        return -0.5 * torch.log(x)
    def domain(self):
        return constraints.positive

    def codomain(self):
        return constraints.positive

    def with_cache(self, cache_size=0):
        return self