Questions about the inference with discrete variable tutorial

  • What tutorial are you running?
    Inference with Discrete Latent Variables
  • What version of Pyro are you using?
    1.8.1
  • Please link or paste relevant code, and steps to reproduce.

Restriction 2: no downstream coupling
No variable outside of a vectorized plate can depend on an enumerated variable inside of that plate. This would violate Pyro’s exponential speedup assumption. For example, the following model is invalid:

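```python
import pyro
import pyro.distributions as dist
from pyro.infer import config_enumerate


@config_enumerate
def invalid_model(data):
    with pyro.plate("plate", 10):  # <--- invalid vectorized plate
        x = pyro.sample("x", dist.Bernoulli(0.5))
    assert x.shape == (10,)
    # data must be passed as obs=...; a positional argument is not treated as an observation
    pyro.sample("obs", dist.Normal(x.sum(), 1.), obs=data)
```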
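A rough way to see where the "exponential speedup assumption" comes from (a sketch, assuming the Bernoulli(0.5) prior above; the per-element observations $d_i$ belong to a hypothetical valid variant, not to the tutorial's model): if each observation $d_i \sim \mathcal{N}(x_i, 1)$ sits inside the plate, the sum over each $x_i$ can be moved inside the product, whereas a single observation $d \sim \mathcal{N}(\sum_i x_i, 1)$ outside the plate couples all ten $x_i$:

$$
p(d_1, \dots, d_{10}) = \prod_{i=1}^{10} \sum_{x_i \in \{0, 1\}} \mathrm{Bernoulli}(x_i \mid 0.5)\, \mathcal{N}(d_i \mid x_i, 1)
$$

$$
p(d) = \sum_{x_1 \in \{0, 1\}} \cdots \sum_{x_{10} \in \{0, 1\}} \Big( \prod_{i=1}^{10} \mathrm{Bernoulli}(x_i \mid 0.5) \Big)\, \mathcal{N}\Big(d \,\Big|\, \sum_{i=1}^{10} x_i,\ 1\Big)
$$

The first expression needs 10 × 2 = 20 terms; the second does not factorize over $i$ and needs $2^{10} = 1024$ terms. As far as I understand, vectorized enumeration inside a plate only computes expressions of the first kind, which is why obs cannot sit outside the plate.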

I tried this model myself and found that if you remove the assert line, the code runs smoothly and doesn't complain about the structure.
The assert fails because x has a different shape when the model is run directly and when it is run under SVI with enumeration, which I believe is unrelated to this restriction.
This is my code:

```python
import pyro
import pyro.distributions as dist
from pyro.infer import TraceEnum_ELBO, config_enumerate
from pyro.infer.autoguide import AutoDiscreteParallel, AutoDelta, AutoNormal


@config_enumerate
def invalid_model():
    with pyro.plate('plate', 10):  # <--- invalid vectorized plate
        x = pyro.sample('x', dist.Bernoulli(0.5))
    pyro.sample('obs', dist.Normal(x.sum(), 1.))


# Hide x from the autoguide so that x is enumerated in the model instead
guide = AutoNormal(pyro.poutine.block(invalid_model, hide=['x']))

pyro.clear_param_store()
elbo = TraceEnum_ELBO()
print(elbo.loss(invalid_model, guide))
```

I believe this part of the tutorial needs some modification, because it gives the impression that an enumerated variable inside a vectorized plate can only appear as a leaf node of the probabilistic graph, which I believe is wrong.

Could someone explain this to me? I can then make some modifications to the tutorial.

By the way, I noticed that the code blocks above always lose their indentation. How can I avoid this?

For a code block, put triple backticks (```) on the line before the code and again on the line after it. For inline code, wrap the text in a single backtick (`) on each side.

Looking at the source code, it seems this is only checked for variables that are enumerated in the guide. For example, using the following guide together with invalid_model from the tutorial will trigger the warning:

```python
@config_enumerate
def guide(data):
    with pyro.plate('plate', 10):  # <--- invalid vectorized plate
        pyro.sample('x', dist.Bernoulli(0.5))
```

but it is not triggered for an empty guide, where x is enumerated in the model:

```python
def guide(data):
    pass
```

Perhaps this could be clarified in the tutorial?

Do you mean that a variable which depends on enumerated variables has to be a leaf node? For example, obs depends on the enumerated variable x, so it has to be nested in the plates that x is nested in.

I will try to make some modifications to the tutorial.

I got it: ‘obs’ has to be in the plates in which ‘x’ is also nested. At first I thought this restriction was too strict, but it seems that I have to live with it~

Thanks for your reply!