Constrain combinations of two parameters

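I previously solved a similar problem (I needed the sum of two random variables to be positive) by passing the sum through a softplus and recording the result as a deterministic site: pyro.deterministic("var", torch.nn.functional.softplus(a + b)). Because softplus(x) = log(1 + exp(x)) is strictly positive, "var" is positive for any values of a and b. In NumPyro the equivalent is numpyro.deterministic("var", jax.nn.softplus(a + b)).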