Let me check that I understand correctly: do you want to learn values of prior that satisfy the following inequality constraint?
0.9 < (prior[2] - prior[0]) ** 2 + (prior[3] - prior[1]) ** 2 < 1.1
If so I’m not sure there’s an easy way to do this directly in Pyro. Let me first try to explain how Pyro’s constraint system works, then suggest some possible solutions.
Pyro’s constraint system allows you to learn parameters obeying simple constraints by transforming those parameters into unconstrained space and learning an unconstrained variable instead. E.g. if we want to learn a positive parameter x, we can instead learn an unconstrained parameter x_unconstrained and deterministically compute x = x_unconstrained.exp(). When you write pyro.param("x", torch.tensor(1.0), constraint=positive), Pyro will under the hood create a different parameter “x_unconstrained” and give you not a parameter but a deterministic function of that parameter at each learning step.
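To make the mechanism concrete, here is a minimal sketch that uses only torch.distributions, whose constraint and transform machinery Pyro's constraints build on. It is not Pyro's internal code: the loss, learning rate, and step count are hypothetical and chosen just to show that the optimizer only ever touches the unconstrained tensor.

```python
import torch
from torch.distributions import constraints, transform_to

# Bijection from the unconstrained reals onto the positive reals.
to_positive = transform_to(constraints.positive)

# The optimizer only ever sees this unconstrained leaf tensor.
# Initializing from x = 1.0 gives x_unconstrained = log(1.0) = 0.0.
x_unconstrained = to_positive.inv(torch.tensor(1.0)).requires_grad_()

optimizer = torch.optim.SGD([x_unconstrained], lr=0.5)
for _ in range(20):
    optimizer.zero_grad()
    x = to_positive(x_unconstrained)  # recomputed each step, always > 0
    loss = (x - 0.01) ** 2            # hypothetical loss pulling x toward 0.01
    loss.backward()
    optimizer.step()

x = to_positive(x_unconstrained).detach()
```

However far the optimizer pushes x_unconstrained, the recomputed x stays positive, which is the guarantee constraint=positive is meant to provide.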
A limitation of Pyro’s constraint system is that you can’t easily combine or conjoin constraints as you would in a constrained optimization problem. However, while Pyro doesn’t allow you to combine hard constraints, it makes it very easy to combine Bayesian priors, which can be seen as softer versions of hard constraints. To combine priors, you can simply add pyro.sample statements to your program.
One possible solution to your problem would be to relax your rod constraint to a prior, e.g. here’s a Gaussian relaxation
rod = (prior[2] - prior[0]) ** 2 + (prior[3] - prior[1]) ** 2
pyro.sample("rod_prior", dist.Normal(1.0, 0.1), obs=rod)
or similarly you could add a pair of factor statements, e.g. here’s a softplus relaxation
rod = (prior[2] - prior[0]) ** 2 + (prior[3] - prior[1]) ** 2
temperature = 0.1 # relaxation hyperparameter
pyro.factor("rod_lb", -torch.nn.functional.softplus((0.9 - rod) / temperature))
pyro.factor("rod_ub", -torch.nn.functional.softplus((rod - 1.1) / temperature))
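To get a feel for how these two relaxations score different values of rod, here is a small torch-only sketch that evaluates the log-density terms outside of any Pyro model. The helper names gaussian_log_factor and softplus_log_factor and the three test values are hypothetical, introduced only for illustration.

```python
import torch
import torch.nn.functional as F
from torch.distributions import Normal


def gaussian_log_factor(rod, loc=1.0, scale=0.1):
    # Log-density contributed by observing `rod` under Normal(loc, scale).
    return Normal(loc, scale).log_prob(rod)


def softplus_log_factor(rod, lower=0.9, upper=1.1, temperature=0.1):
    # Sum of the two soft-bound terms; each is close to 0 well inside
    # the interval and decreases roughly linearly outside it.
    lb = -F.softplus((lower - rod) / temperature)
    ub = -F.softplus((rod - upper) / temperature)
    return lb + ub


# One value below the interval, one in the middle, one above it.
rods = torch.tensor([0.5, 1.0, 1.5])
gaussian_terms = gaussian_log_factor(rods)
softplus_terms = softplus_log_factor(rods)
```

Both relaxations favor rod = 1.0 over the two out-of-range values. Lowering temperature makes the softplus version behave more like the hard bounds, at the cost of steeper gradients near them.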
Another solution might be to do some algebra and express, say, prior[2] as a function of prior[0], and then prior[3] as a function of prior[0:2]. This would preserve the hard constraint in your example, but as your models become more complex this kind of trick won’t scale. Here’s some actual math for your example:
import torch
import pyro
from pyro.distributions import constraints

prior_0 = pyro.param("prior_0", torch.tensor(0.))  # unconstrained
prior_2_minus_0 = pyro.param(
    "prior_2_minus_0",
    torch.tensor(0.0),
    constraint=constraints.interval(-0.1**0.5, 0.1**0.5),
)
prior_2 = prior_0 + prior_2_minus_0
prior_1 = pyro.param("prior_1", torch.tensor(0.))  # unconstrained
prior_3_minus_1_scaled = pyro.param(
    "prior_3_minus_1_scaled",
    torch.tensor(0.5),  # start inside the interval rather than on its boundary
    constraint=constraints.unit_interval,  # the keyword is `constraint`
)
prior_3 = prior_1 + prior_3_minus_1_scaled * (1 - prior_2_minus_0.abs())
# The four pieces are scalars, so stack them (torch.cat rejects 0-dim tensors).
prior = pyro.deterministic(
    "prior", torch.stack([prior_0, prior_1, prior_2, prior_3], -1)
)
Yikes, it looked simpler to relax!
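If you do want to keep the constraint hard, note that it confines the difference vector (prior[2] - prior[0], prior[3] - prior[1]) to an annulus, so polar coordinates are one way to satisfy it by construction. Below is a minimal torch-only sketch of that idea, not a tested Pyro model: build_prior, rod_sq_unconstrained, and angle are hypothetical names, and the input values are arbitrary.

```python
import torch
from torch.distributions import constraints, transform_to

# Bijection from the unconstrained reals onto the open interval (0.9, 1.1),
# used here for the squared rod length.
to_rod_sq = transform_to(constraints.interval(0.9, 1.1))


def build_prior(prior_0, prior_1, rod_sq_unconstrained, angle):
    # Rod length is the square root of a value confined to (0.9, 1.1).
    rod = to_rod_sq(rod_sq_unconstrained).sqrt()
    # Place the second endpoint at distance `rod` from the first.
    prior_2 = prior_0 + rod * torch.cos(angle)
    prior_3 = prior_1 + rod * torch.sin(angle)
    return torch.stack([prior_0, prior_1, prior_2, prior_3], -1)


# Four free, unconstrained scalars (arbitrary illustrative values).
prior = build_prior(
    torch.tensor(0.3), torch.tensor(-1.2), torch.tensor(0.5), torch.tensor(2.0)
)
rod_sq = (prior[2] - prior[0]) ** 2 + (prior[3] - prior[1]) ** 2
```

In a Pyro program the same idea would presumably use pyro.param with constraint=constraints.interval(0.9, 1.1) for the squared length and unconstrained parameters for prior_0, prior_1, and the angle. It still shares the drawback mentioned above: the trick is specific to this one constraint.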