Defining dependent constraints

Hi all,

I am trying to implement a model in which I have, among other parameters, 4 parameters for which I want to obtain the MAP estimates.

For these 4 parameters (wb1, wb2, wb3, wb4) there are some constraints: they should all be positive and ordered, i.e. wb2 >= wb1, wb3 >= wb2, and wb4 >= wb3.

Currently I have implemented this as:

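import torch
import pyro
import pyro.distributions as dist
from pyro.distributions import constraints

def model(data):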
    ...
    wb1 = pyro.sample("wb1", dist.Uniform(0, 10))
    wb2 = pyro.sample("wb2", dist.Uniform(0, 10))
    wb3 = pyro.sample("wb3", dist.Uniform(0, 10))
    wb4 = pyro.sample("wb4", dist.Uniform(0, 10))
    ...

def guide(data):
    ...
    wb1_loc = pyro.param("wb1_loc", lambda: torch.tensor(0.2))
    wb2_loc = pyro.param(
        "wb2_loc",
        lambda: torch.tensor(0.4),
        constraint=constraints.greater_than(wb1_loc),
    )
    wb3_loc = pyro.param(
        "wb3_loc",
        lambda: torch.tensor(0.6),
        constraint=constraints.greater_than(wb2_loc),
    )
    wb4_loc = pyro.param(
        "wb4_loc",
        lambda: torch.tensor(0.8),
        constraint=constraints.greater_than(wb3_loc),
    )

    # MAP estimates for wb
    wb1 = pyro.sample("wb1", dist.Delta(wb1_loc))
    wb2 = pyro.sample("wb2", dist.Delta(wb2_loc))
    wb3 = pyro.sample("wb3", dist.Delta(wb3_loc))
    wb4 = pyro.sample("wb4", dist.Delta(wb4_loc))
    ...

I run inference using SVI. The other parameters in my model are updated as I expect, but these four wb parameters are not.

I tried two ways to inspect the values after inference. First, by looking in the param store: there, all wb*_loc are still equal to their initial values. Second, by using pyro.infer.Predictive with return_sites=(other_params, wb1, wb2, wb3, wb4): same problem.

What am I doing wrong?

Thanks a lot!

Hi @eppow, here’s an idea: run inference with unordered parameters and record the ordered (sorted) parameters:

def model(data):
    ...
    wb = pyro.sample("wb", dist.Uniform(torch.zeros(4), 10).to_event(1))
    wb1, wb2, wb3, wb4 = wb.sort(-1).values.unbind(-1)
    ...

If you want these accessible in the trace, you can add pyro.deterministic statements:

def model(data):
    ...
    wb = pyro.sample("wb", dist.Uniform(torch.zeros(4), 10).to_event(1))
    wb1, wb2, wb3, wb4 = wb.sort(-1).values.unbind(-1)
    pyro.deterministic("wb1", wb1)
    pyro.deterministic("wb2", wb2)
    pyro.deterministic("wb3", wb3)
    pyro.deterministic("wb4", wb4)
    ...

Hi @fritzo,

Thanks for the suggestion! This indeed seems like a better way to do it.

My variables still aren’t being updated, but I suppose there is something unrelated to this topic going on then…

Thanks!

Note that you’ll now need a single pyro.sample("wb", ...) statement in the guide, rather than the separate wb1, wb2, wb3, wb4 sample statements.