I have the following numpyro code, which uses a transformed distribution to ensure a pair of normally distributed variables are positive and ordered.
(Context: I’m sampling a dataset of N elements, where each element is a pair of values p = (x0, x1) with the constraint that x1 > x0 and x0, x1 > 0. x0 and x1 are each sampled from a mixture model of two clusters. The cluster means/standard deviations for x0 and x1 are different, but the cluster assignment is the same.)
```python
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.distributions import transforms

# x0_prior, x1_prior: arrays of shape (num_clusters, 2), each row holding (mean, std)
def positive_ordered_model(N, pi, x0_prior, x1_prior):
    with numpyro.plate("N", N):
        # i = cluster assignment (shared by x0 and x1)
        i = numpyro.sample("i", dist.Categorical(jnp.array(pi)), infer={'enumerate': 'parallel'})
        mean_x0 = x0_prior[i, 0]
        mean_x1 = x1_prior[i, 0]
        std_x0 = x0_prior[i, 1]
        std_x1 = x1_prior[i, 1]
        mean = jnp.dstack([mean_x0, mean_x1])
        std = jnp.dstack([std_x0, std_x1])
        # ensure the two elements of p are ordered and positive
        p = numpyro.sample("p", dist.TransformedDistribution(
            dist.Normal(mean, std),
            transforms.ComposeTransform([transforms.ExpTransform(), transforms.OrderedTransform()])
        ))
        # take along last axis, deal with batching and variable number of dimensions
        p0 = numpyro.deterministic("p0", p.take(0, axis=-1))
        p1 = numpyro.deterministic("p1", p.take(1, axis=-1))
```
I am using a `ComposeTransform` of `ExpTransform` and `OrderedTransform` for this, but it does not do what I want: only one of the two constraints seems to hold, depending on the order in which the transforms are passed.
If I use

```python
transforms.ComposeTransform([transforms.ExpTransform(), transforms.OrderedTransform()])
```

then all the resulting samples of `p` are ordered (second element is always bigger than the first element) but not all of them are positive.
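Assuming NumPyro's `OrderedTransform` is the cumulative-sum map that keeps the first coordinate and adds the exponential of each later one (for a pair, $y_0 = u_0$ and $y_1 = u_0 + e^{u_1}$), this list works out on paper, for a base draw $(x_0, x_1)$, to

$$
\begin{aligned}
u_0 &= e^{x_0}, \qquad u_1 = e^{x_1}, \\
p_0 &= u_0 = e^{x_0} > 0, \\
p_1 &= u_0 + e^{u_1} = e^{x_0} + e^{e^{x_1}} > p_0.
\end{aligned}
$$

So in exact arithmetic this order should already give pairs that are both positive and ordered.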
If instead I use

```python
transforms.ComposeTransform([transforms.OrderedTransform(), transforms.ExpTransform()])
```

then all samples of `p` are positive but not all are ordered.
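Under the same assumption about `OrderedTransform` (for a pair, $y_0 = x_0$ and $y_1 = x_0 + e^{x_1}$), the reversed list orders first and exponentiates second:

$$
\begin{aligned}
v_0 &= x_0, \qquad v_1 = x_0 + e^{x_1} > v_0, \\
p_0 &= e^{v_0} = e^{x_0} > 0, \\
p_1 &= e^{v_1} = e^{x_0 + e^{x_1}} > p_0.
\end{aligned}
$$

Since $\exp$ is strictly increasing, it keeps the ordering produced by the first step, so on paper this order should also yield positive, ordered pairs (up to floating-point rounding).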
See this Google Colab notebook for more details - Numpyro ComposeTransform not working for Positive and Ordered Transform
Any help is appreciated. Thanks in advance.