```python
import jax.numpy as jnp
from numpyro import distributions as dist

dist1 = dist.LogNormal(0, 1)
# 10 * Chi2(2): an affine map of a positive variable, with its domain set to dist1's support
dist2 = dist.TransformedDistribution(
    dist.Chi2(2),
    dist.transforms.AffineTransform(0.0, 10, domain=dist1.support),
)
print(dist1.support, dist2.support)

# Two-component mixture with weights 0.2 and 0.8
mixing = dist.Categorical(probs=jnp.array([0.2, 0.8]))
dist.Mixture(mixing, [dist1, dist2])
```
When running the above code, I got:

```
ValueError: All component distributions must have the same support.
```
The supports of the two distributions are `Positive(lower_bound=0.0)` and `GreaterThan(lower_bound=0.0)`, even though both describe the same set of values. How can I make them have the same support?