How to make my Reparam return the square root of variable?

I have implemented a reparameterization of the half-student-t distribution following “Mean field variational Bayes for elaborate distributions. Bayesian Analysis, 6(4), 2011”:

As you can see, this reparameterization returns x^2, so the reparam needs to return the square root. How can I do so?

It is implemented as followings:

# TODO: NOT DONE! this returns x^2, not x
class HalfStudentTReparam(Reparam):
    def apply(self, msg):
        name = msg["name"]
        fn = msg["fn"]
        value = msg["value"]
        is_observed = msg["is_observed"]

        fn, event_dim = self._unwrap(fn)
        assert isinstance(fn, HalfStudentT)

        rate = 1 / (fn.scale*fn.scale)
        a = pyro.sample(f'{name}_invgamma', 
            self._wrap(dist.InverseGamma(0.5, rate), event_dim))
        
        new_fn = self._wrap(dist.InverseGamma(fn.df*0.5, fn.df/a), event_dim)
        return {'fn': new_fn, 'value': value, 'is_observed': is_observed}

Since the half-student-t distribution is not part of pyro and pytorch, I offer the code as well. Not sure if it is correct though…

class HalfStudentT(TransformedDistribution, TorchDistributionMixin):
    arg_constraints = {
        'df': constraints.positive,
        'scale': constraints.positive,
    }
    support = constraints.nonnegative
    has_rsample = True

    def __init__(self, df, scale=1.0, validate_args=None):
        base_dist = StudentT(df, scale=scale, validate_args=False)
        super().__init__(base_dist, AbsTransform(), validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(HalfStudentT, _instance)
        return super().expand(batch_shape, _instance=new)

    @property
    def scale(self):
        return self.base_dist.scale

    @property
    def df(self):
        return self.base_dist.df

    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        log_prob = self.base_dist.log_prob(value) + math.log(2)
        log_prob = torch.where(value >= 0, log_prob, -inf)
        return log_prob

A small innocent bump