Trying to understand params with constraints

I don’t understand why a param with a constraint becomes a non-leaf Tensor.

The code below, without a constraint, works well.

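import pandas as pd
import pyro
import pyro.distributions as dist
from pyro import param, sample
from pyro.poutine import condition, trace
from torch import tensor
from torch.distributions import constraints
from torch.optim import Adam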

pyro.clear_param_store()
def model():
    mu = param("mu", tensor(0.))
    return sample("x", dist.Normal(mu, 1))

model() # Instantiate the mu parameter
cond_model = condition(model, {"x": tensor(5.0)})

# Large learning rate for demonstration purposes
optimizer = Adam([param("mu")], lr=0.01)
mus = []
losses = []
for _ in range(1000):
    tr = trace(cond_model).get_trace()

    # Optimizers minimize, so use the negative log probability
    # of the conditioned trace as the loss
    prob = -tr.log_prob_sum()
    prob.backward()

    # Update parameters according to optimization strategy
    optimizer.step()

    # Zero all parameter gradients so they don't accumulate
    optimizer.zero_grad()

    # Record probability (or "loss") along with current mu
    losses.append(prob.item())
    mus.append(param("mu").item())

pd.DataFrame({"mu": mus, "loss": losses}).plot(subplots=True)

If I change only mu = param("mu", tensor(0.)) to mu = param("mu", tensor(0.), constraint=constraints.greater_than(0)), I get the error ValueError: can't optimize a non-leaf Tensor.

Stepping through with the PyCharm debugger, I found that a param with a constraint becomes a non-leaf Tensor.
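The same error can be reproduced in plain PyTorch, without Pyro. This is a minimal sketch that uses torch.distributions.transform_to as a stand-in for the kind of constrained mapping Pyro performs internally; it mimics the mechanism but is not Pyro's own code. The names unconstrained, constrained and to_positive are illustrative only.

```python
import torch
from torch.distributions import constraints, transform_to

# Leaf tensor: created directly with requires_grad=True.
# It plays the role of the unconstrained parameter.
unconstrained = torch.tensor(0.0, requires_grad=True)

# Map it into (0, inf). The result is computed from the leaf, so it carries
# a grad_fn and is NOT a leaf. It plays the role of the constrained value.
to_positive = transform_to(constraints.greater_than(0.0))
constrained = to_positive(unconstrained)

print(unconstrained.is_leaf)  # True
print(constrained.is_leaf)    # False

# torch.optim rejects non-leaf tensors, which is the error in the question.
try:
    torch.optim.Adam([constrained], lr=0.01)
except ValueError as err:
    print("ValueError:", err)

# Optimizing the leaf works, and gradients flow back through the transform.
optimizer = torch.optim.Adam([unconstrained], lr=0.01)
loss = (to_positive(unconstrained) - 5.0) ** 2
loss.backward()
optimizer.step()
print(unconstrained.grad)  # 2 * (1 - 5) * exp(0) = -8.0
```

The key point is that "leaf" is a property of how a tensor was produced: anything computed from a parameter, including a constrained view of it, is not a leaf, so only the tensor it was computed from can be handed to an optimizer.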

I just can’t understand the mechanism here. Can anyone help?

PS: Last year I posted a question, ValueError: can't optimize a non-leaf Tensor - #2 by martinjankowiak, but I still don’t understand the background. I’m not looking for a solution; I just want to know why.

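Pyro imposes constraints by creating an underlying unconstrained parameter, which is the tensor PyTorch actually optimizes, and mapping it through a differentiable transform into the constrained space each time pyro.param(...) is called; pyro.param(...) returns that mapped value. Because the mapped value is computed from the unconstrained tensor, it has a grad_fn and is therefore a non-leaf tensor, which torch.optim optimizers refuse to accept; the unconstrained tensor is the leaf. If you want to optimize the underlying unconstrained parameter directly, you can get it through the .unconstrained attribute, which is a weak reference, so you call it: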

pyro.param(name).unconstrained()