# How can I make my Reparam return the square root of a variable?

I have implemented a reparameterization of the half-Student-t distribution following "Mean field variational Bayes for elaborate distributions", Bayesian Analysis, 6(4), 2011, which represents it through two inverse-gamma distributions.
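For reference, here is that representation written with the parameters the code below uses: $A$ = `fn.scale` and $\nu$ = `fn.df`. InverseGamma is in the shape/rate form taken by `dist.InverseGamma(concentration, rate)`, with density proportional to $x^{-\alpha-1} e^{-\beta/x}$:

```latex
\begin{aligned}
a &\sim \mathrm{InverseGamma}\left(\tfrac{1}{2},\, \tfrac{1}{A^{2}}\right), \\
x^{2} \mid a &\sim \mathrm{InverseGamma}\left(\tfrac{\nu}{2},\, \tfrac{\nu}{a}\right)
\quad\Longrightarrow\quad
x = \sqrt{x^{2}} \sim \text{Half-}t(\nu, A).
\end{aligned}
```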

As you can see, this reparameterization produces x^2 rather than x, so the reparam still needs to take the square root. How can I do that?

Here is my implementation:

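```python
import pyro
import pyro.distributions as dist
from pyro.infer.reparam.reparam import Reparam


# TODO: NOT DONE! this returns x^2, not x
class HalfStudentTReparam(Reparam):
    def apply(self, msg):
        name = msg["name"]
        fn = msg["fn"]
        value = msg["value"]
        is_observed = msg["is_observed"]

        fn, event_dim = self._unwrap(fn)
        assert isinstance(fn, HalfStudentT)  # HalfStudentT is defined below

        # Auxiliary variable: a ~ InverseGamma(1/2, 1/scale^2)
        rate = 1 / (fn.scale * fn.scale)
        a = pyro.sample(f"{name}_invgamma",
                        self._wrap(dist.InverseGamma(0.5, rate), event_dim))

        # Conditional: x^2 | a ~ InverseGamma(df/2, df/a), i.e. the SQUARE of the target
        new_fn = self._wrap(dist.InverseGamma(fn.df * 0.5, fn.df / a), event_dim)
        return {"fn": new_fn, "value": value, "is_observed": is_observed}
```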

Since the half-Student-t distribution is not part of Pyro or PyTorch, here is my implementation of it as well. I'm not sure it is correct, though.

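```python
import math

import torch
from torch.distributions import StudentT, TransformedDistribution, constraints
from torch.distributions.transforms import AbsTransform
from pyro.distributions.torch_distribution import TorchDistributionMixin


class HalfStudentT(TransformedDistribution, TorchDistributionMixin):
    arg_constraints = {
        "df": constraints.positive,
        "scale": constraints.positive,
    }
    support = constraints.nonnegative
    has_rsample = True

    def __init__(self, df, scale=1.0, validate_args=None):
        # |T| with T ~ StudentT(df, loc=0, scale) is half-Student-t
        base_dist = StudentT(df, scale=scale, validate_args=False)
        super().__init__(base_dist, AbsTransform(), validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(HalfStudentT, _instance)
        return super().expand(batch_shape, _instance=new)

    @property
    def scale(self):
        return self.base_dist.scale

    @property
    def df(self):
        return self.base_dist.df

    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        # Fold the symmetric StudentT density onto [0, inf): p(x) = 2 * t(x) for x >= 0
        log_prob = self.base_dist.log_prob(value) + math.log(2)
        # `inf` was undefined in the original; use math.inf
        log_prob = torch.where(value >= 0, log_prob, torch.full_like(log_prob, -math.inf))
        return log_prob
```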
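One way to sanity-check the `log_prob` above: with `df = 1` the half-Student-t reduces to the half-Cauchy, which PyTorch does ship, and for any `df` the density should integrate to 1. The sketch below repeats the same formula in a standalone helper, `half_t_log_prob` (a hypothetical name, not part of any library), so it runs without the class.

```python
import math

import torch
from torch.distributions import HalfCauchy, StudentT


def half_t_log_prob(value, df, scale):
    """Hypothetical standalone helper: same formula as HalfStudentT.log_prob above."""
    base = StudentT(torch.as_tensor(df), scale=torch.as_tensor(scale))
    log_prob = base.log_prob(value) + math.log(2)  # fold the symmetric density onto [0, inf)
    return torch.where(value >= 0, log_prob, torch.full_like(log_prob, -math.inf))


# Check 1: with df = 1 the half-Student-t is the half-Cauchy
x = torch.linspace(0.1, 10.0, 100)
lp_half_t = half_t_log_prob(x, 1.0, 2.0)
lp_half_cauchy = HalfCauchy(torch.tensor(2.0)).log_prob(x)

# Check 2: the density integrates to (approximately) 1 on [0, 200] for another df
grid = torch.linspace(0.0, 200.0, 200_001)
total = torch.trapezoid(half_t_log_prob(grid, 4.0, 1.0).exp(), grid).item()
print(torch.max((lp_half_t - lp_half_cauchy).abs()).item(), total)
```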

A small innocent bump

Bumping again. Would appreciate any kind of input

Hi @mochar, I think you can define a `SquareRoot` transform whose domain and codomain are both the positive reals. Then you can do:

```python
new_fn = self._wrap(dist.TransformedDistribution(dist.InverseGamma(...), SquareRoot()), ...)
```

Thank you so much!

Does this implementation of the square root transform look good?

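```python
import math

import torch
from torch.distributions import constraints
from torch.distributions.transforms import Transform


class SquareRootTransform(Transform):
    """y = sqrt(x), a bijection from the positive reals to the positive reals."""

    domain = constraints.positive
    codomain = constraints.positive
    bijective = True
    sign = +1  # monotonically increasing

    def __init__(self, cache_size=0):
        super().__init__(cache_size=cache_size)

    def _call(self, x):
        # Override _call, not __call__: Transform.__call__ handles caching, then calls _call
        return x.sqrt()

    def _inverse(self, y):
        return y ** 2

    def log_abs_det_jacobian(self, x, y):
        # dy/dx = 1 / (2 * sqrt(x)), so log|dy/dx| = -log(2) - 0.5 * log(x)
        return -math.log(2) - 0.5 * torch.log(x)

    def with_cache(self, cache_size=1):
        # Return an instance that honours the requested cache size instead of ignoring it
        if self._cache_size == cache_size:
            return self
        return SquareRootTransform(cache_size=cache_size)
```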
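The Jacobian term is the easiest part to get wrong. For y = sqrt(x), dy/dx = 1/(2 sqrt(x)), so log|dy/dx| = -log 2 - 0.5 log x. A version with only `-0.5 * torch.log(x)` drops the constant -log 2, which leaves the resulting density off by a factor of 2. A small autograd check of the closed form, using arbitrary positive test points:

```python
import math

import torch

x = torch.tensor([0.25, 1.0, 4.0, 9.0], requires_grad=True)
y = x.sqrt()

# Each y_i depends only on x_i, so the gradient of y.sum() is the element-wise dy/dx
(dy_dx,) = torch.autograd.grad(y.sum(), x)
ladj_autograd = dy_dx.abs().log()

# Closed form used in log_abs_det_jacobian: -log(2) - 0.5 * log(x)
ladj_closed = -math.log(2) - 0.5 * torch.log(x.detach())

# Version without the -log(2) term, for comparison
ladj_missing_log2 = -0.5 * torch.log(x.detach())
print(ladj_autograd, ladj_closed, ladj_missing_log2)
```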