Using SVI to get conditional probabilities for a data batch


I'm struggling with a basic inference problem in Pyro. Here is my question:
Given a probabilistic graphical model with fixed parameters, how can I use SVI to predict the conditional probabilities of its nodes for each input datum?

In the simple example below, b depends on a, and the parameters are fixed in weight. Given the value of a, I want to get the conditional probability of b.

The most straightforward way is to use TraceEnum_ELBO.compute_marginals() to get the exact result, which is obvious in this example. But when the graph is dense, enumeration becomes very slow. So I tried to work around this with SVI, but the result is not right.

```python
from collections import defaultdict

import torch
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer import SVI, TraceEnum_ELBO
from pyro.ops.indexing import Vindex
from torch.distributions import constraints

smoke_test = False

def model(data):
    weight = pyro.param("weight", torch.tensor([[0.3, 0.7], [0.4, 0.6]]), constraints.unit_interval)

    with pyro.plate("data", len(data)):
        a = pyro.sample("a", dist.Categorical(torch.tensor([0.5, 0.5])), obs=data[:, 0])

        # Row a of weight holds p(b | a).
        weights_b = Vindex(weight)[a, :]

        b = pyro.sample("b", dist.Categorical(weights_b), infer={"enumerate": "parallel"})

def guide(data):
    poutine.block(model, hide=["weight"])

    with pyro.plate("data", len(data)):
        weight_guide = pyro.param('weight_guide', torch.tensor([0.5, 0.5]),
                                  constraint=constraints.unit_interval)
        b = pyro.sample('b', dist.Categorical(weight_guide))

def infer():
    optim = pyro.optim.Adam({'lr': 0.2, 'betas': [0.8, 0.99]})
    elbo = TraceEnum_ELBO()
    svi = SVI(model, guide, optim, loss=elbo)

    # Register hooks to monitor gradient norms.
    gradient_norms = defaultdict(list)
    svi.loss(model, guide, data)  # Initializes param store.
    for name, value in pyro.get_param_store().named_parameters():
        value.register_hook(lambda g, name=name: gradient_norms[name].append(g.norm().item()))

    losses = []
    for i in range(1000 if not smoke_test else 2):
        loss = svi.step(data)
        losses.append(loss)
        print('.' if i % 100 else '\n', end='')

if __name__ == '__main__':
    data = torch.tensor([[0], [1]])
    infer()
    print('weight_guide = {}'.format(pyro.param('weight_guide')))
    print('weight = {}'.format(pyro.param('weight')))
```

The result is:

weight_guide = tensor([0.582313, 0.814295], grad_fn=<ClampBackward>)
weight = tensor([[0.584472, 0.816690],
        [0.593544, 0.829365]], grad_fn=<ClampBackward>)

It seems that both the values and the shape of weight_guide are wrong. The poutine.block also seems not to work, because the value of weight changes.

So where is the problem in this code?


Hi @lyy, I don't fully understand your (model, guide) pair, but here are some suggestions for your guide:

First, the statement poutine.block(model, hide=["weight"]) has no effect: it builds a blocked version of model and then discards it without ever calling it. I'm not sure of its intended purpose, but I believe you can simply remove it.

Second, the statement pyro.param('weight_guide', ...) acts as a global parameter, whereas I believe you intend a local parameter (one set of weights per datum). To make it local, give it one row of weights per datum and specify an event_dim, so Pyro knows which rightmost dimensions belong to each datum and can subsample the remaining batch dimension together with the plate. Also, since those weights are probabilities, you should use a simplex constraint (or use Bernoulli with a unit_interval constraint):

```python
# Version 1. using Categorical with constraints.simplex:
with pyro.plate("data", len(data)):
    weight_guide = pyro.param("weight_guide", torch.ones(len(data), 2) / 2,
                              constraint=constraints.simplex, event_dim=1)
    b = pyro.sample("b", dist.Categorical(weight_guide))


# Version 2. using Bernoulli with constraints.unit_interval:
with pyro.plate("data", len(data)):
    weight_guide = pyro.param("weight_guide", torch.full((len(data),), 0.5),
                              constraint=constraints.unit_interval, event_dim=0)
    b = pyro.sample("b", dist.Bernoulli(weight_guide))
```

Third, I believe your overall task could be accomplished more simply by indexing into the weights tensor, but I may be misunderstanding:

```python
marginals_b = weight[data]  # does this directly compute marginals?
```
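Here is the indexing with the shapes spelled out, assuming data has the (N, 1) layout from the question. Indexing with data itself keeps its extra singleton dimension, while data[:, 0] gives exactly one row of p(b | a) per datum.

```python
import torch

weight = torch.tensor([[0.3, 0.7], [0.4, 0.6]])  # row i holds p(b | a = i)
data = torch.tensor([[0], [1]])                  # same layout as in the question

print(weight[data].shape)         # torch.Size([2, 1, 2]) because data has shape (2, 1)
marginals_b = weight[data[:, 0]]  # shape (2, 2): one row p(b | a) per datum
print(marginals_b)                # equals weight here, since data lists a = 0, 1 in order
```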

@fritzo Thanks for the reply,

In fact, I wrote this guide by referring to a snippet from the GMM example:

```python
def full_guide(data):
    # Global variables.
    with poutine.block(hide_types=["param"]):  # Keep our learned values of global parameters.
        ...  # the global guide is called here (omitted from this snippet)

    # Local variables.
    with pyro.plate('data', len(data)):
        assignment_probs = pyro.param('assignment_probs', torch.ones(len(data), K) / K,
                                      constraint=constraints.unit_interval)
        pyro.sample('assignment', dist.Categorical(assignment_probs))
```

In the GMM example, the global parameters have already been learned and are held fixed while predicting the cluster membership of each data point, which is exactly what I want. Its structure feels similar to my example, yet it uses poutine.block without any event_dim to accomplish this, so your answer has left me more confused.

As for your third point, it can certainly be computed by indexing. However, my real problem has more nodes with many more categories.
For example, a, b, c, d and e each have 10000 categories, connected as a -> b -> c -> d -> e. Given a batch of values of a, it's easy to get the conditional probability of b by indexing. But for e, the enumeration over b, c, d and e has 10000**4 terms for each input. This is why I want to turn to SVI as an alternative.
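For a pure chain like this, exact conditionals do not require enumerating all joint assignments: summing out b, c and d one at a time turns p(e | a) into a product of the conditional probability tables, roughly 3·K² multiply-adds per input (K is the number of categories) instead of K⁴ joint terms. With K = 10000 that is still about 3·10⁸ per input, and each 10000 × 10000 table takes 400 MB in float32. Below is a minimal sketch with small random tables (hypothetical stand-ins for the real ones), checked against brute-force enumeration.

```python
import torch

torch.manual_seed(0)
K = 5  # small category count so the brute-force check below is cheap

def random_cpt(k):
    # Hypothetical conditional probability table: row i is p(child | parent = i).
    return torch.softmax(torch.randn(k, k), dim=-1)

W_ab, W_bc, W_cd, W_de = (random_cpt(K) for _ in range(4))
a = torch.tensor([0, 3, 1])  # a batch of observed values of a

# Sum out b, c, d one at a time. Python's @ is left-associative, so each step
# is a (batch, K) @ (K, K) product rather than a K x K x K matrix-matrix product.
p_e = W_ab[a] @ W_bc @ W_cd @ W_de  # shape (batch, K): p(e | a)

# Brute force: build the full joint over (b, c, d, e) per input and sum out b, c, d.
joint = torch.einsum("nb,bc,cd,de->nbcde", W_ab[a], W_bc, W_cd, W_de)
p_e_brute = joint.sum(dim=(1, 2, 3))

assert torch.allclose(p_e, p_e_brute, atol=1e-5)
assert torch.allclose(p_e.sum(-1), torch.ones(len(a)))
```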

This is my first time using a PPL, so I obviously don't have much insight into it yet. I will try to understand your answer.

I hope I have made this clear, and thanks again for your kind reply :grinning: