Update rules for constrained parameters

Hello,
I am trying to implement a variation of the ELBO called GECO (see paper [1810.00597] Taming VAEs).

It seems that Pyro has all the machinery necessary to implement this fairly easily. I just need to know the details of how Pyro updates constrained parameters.

For example, assume that your loss function is:

Loss = f + mu * g

where mu was declared as a pyro.param with constraint=constraints.unit_interval.

What is the update rule for mu?
For an unconstrained parameter (with plain SGD), it is mu -= learning_rate * dLoss/d(mu).
What is the update rule if mu is constrained to be in the unit interval?
What is the update rule if mu is constrained to be positive?

Thanks!

These are just PyTorch constraints. For example, to get something constrained to the unit interval, by default one applies a sigmoid transformation to an unconstrained parameter. The transformation used for positive is the exponential, and so on.
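Applying the chain rule through those transforms gives the effective update rules asked about above. This is a sketch assuming plain SGD with learning rate η on the unconstrained parameter u (with Adam, the same gradient ∂L/∂u is passed through Adam's adaptive step instead). For the question's loss, ∂L/∂μ = g when neither f nor g depends on μ:

```latex
\begin{aligned}
\text{unit interval: } \mu &= \sigma(u), &
\frac{\partial L}{\partial u} &= \frac{\partial L}{\partial \mu}\,\mu(1-\mu), &
u &\leftarrow u - \eta\,\frac{\partial L}{\partial \mu}\,\mu(1-\mu), &
\Delta\mu &\approx -\eta\,\frac{\partial L}{\partial \mu}\,\mu^{2}(1-\mu)^{2} \\
\text{positive: } \mu &= e^{u}, &
\frac{\partial L}{\partial u} &= \frac{\partial L}{\partial \mu}\,\mu, &
u &\leftarrow u - \eta\,\frac{\partial L}{\partial \mu}\,\mu, &
\Delta\mu &\approx -\eta\,\frac{\partial L}{\partial \mu}\,\mu^{2}
\end{aligned}
```

After each step μ is recomputed as σ(u) or e^u, so it can never leave its constraint set; the last column is the first-order change in μ itself, which shows the step shrinking as μ approaches the boundary.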

That's right. Pyro doesn't actually update mu directly; instead it creates an unconstrained parameter mu_unconstrained and uses a standard optimizer such as Adam to update that parameter. Then when you ask for the latest mu, Pyro computes

mu = transform_to(constraints.unit_interval)(mu_unconstrained)

where transform_to(constraints.unit_interval) just returns a sigmoid transform.
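A minimal sketch of that mechanism in plain PyTorch (not Pyro's param store itself): the optimizer only ever sees mu_unconstrained, and mu is recomputed through transform_to. Plain SGD stands in for Adam so the arithmetic stays transparent, and f, g, lr and the helpers sgd_step, expected_unit_interval and expected_positive are hypothetical names for illustration. The closed-form helpers apply the chain rule dLoss/du = dLoss/dmu * dmu/du by hand, so both paths should agree up to float32 rounding.

```python
import math

import torch
from torch.distributions import constraints, transform_to

# Hypothetical stand-ins for Loss = f + mu * g. Neither f nor g depends on mu,
# so dLoss/dmu = g.
f = 1.5
g = 2.0
lr = 0.1


def sgd_step(constraint, mu_init):
    """One plain-SGD step taken on the unconstrained parameter.

    Returns (mu_before, mu_after), both in constrained space.
    """
    transform = transform_to(constraint)
    # Map the initial constrained value into unconstrained space; this leaf
    # tensor is the only thing the optimizer ever updates.
    mu_unconstrained = transform.inv(torch.tensor(mu_init)).requires_grad_()
    optimizer = torch.optim.SGD([mu_unconstrained], lr=lr)

    mu = transform(mu_unconstrained)  # constrained value used in the loss
    loss = f + mu * g
    optimizer.zero_grad()
    loss.backward()   # gradient flows back through the transform
    optimizer.step()  # updates mu_unconstrained, not mu

    with torch.no_grad():
        mu_after = transform(mu_unconstrained)
    return mu.item(), mu_after.item()


def expected_unit_interval(mu):
    """Closed form: u = logit(mu), u -= lr * g * mu * (1 - mu), mu = sigmoid(u)."""
    u = math.log(mu) - math.log1p(-mu)
    u -= lr * g * mu * (1 - mu)
    return 1 / (1 + math.exp(-u))


def expected_positive(mu):
    """Closed form: u = log(mu), u -= lr * g * mu, mu = exp(u)."""
    u = math.log(mu) - lr * g * mu
    return math.exp(u)


for constraint, expected in [(constraints.unit_interval, expected_unit_interval),
                             (constraints.positive, expected_positive)]:
    before, after = sgd_step(constraint, 0.3)
    print(constraint, before, after, expected(0.3))
```

With g > 0 both constraints pull mu down, but by much less than the naive constrained-space step mu -= lr * g would (0.3 - 0.2 = 0.1 here), because the gradient is scaled by mu * (1 - mu) or by mu. With Adam, the same chain-rule gradient is fed into Adam's adaptive step, so the effective update is no longer this simple closed form.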