Inference result depends on whether `sequential` or `parallel` is used as the enumeration strategy

Hi!

I am a bit confused about the enumeration settings. Here is a minimal example:

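```python
import pyro
import pyro.distributions as dist
import torch
from pyro.infer import config_enumerate
from pyro.infer import infer_discrete

@config_enumerate
def model(x_pa_obs=None, x_ch_obs=None, y_obs=None):
    # Parent variable: its observed value is used directly as the success probability of y.
    p = x_pa_obs
    # Discrete latent y; change "sequential" to "parallel" to compare the two strategies.
    y = pyro.sample('y_pre', dist.Binomial(probs=p),
                    infer={"enumerate": "sequential"},
                    obs=y_obs)

    # Child variable: Gaussian observation centered on y.
    d_ch = dist.Normal(y, 1.0)
    x_ch_pre = pyro.sample('x_ch_pre', d_ch, obs=x_ch_obs)

    return y

data_obs = {'x_pa_obs': torch.tensor(0.5), 'x_ch_obs': torch.tensor(1.0)}
# temperature=1 draws posterior samples of y (temperature=0 would return the MAP value).
model_discrete = infer_discrete(model, first_available_dim=-1, temperature=1)

# Draw 10^4 posterior samples of y and average them to estimate p(y = 1 | x_ch, x_pa).
y_posts = []
for ii in range(10**4):
    print(f'iteration {ii}', end='\r')
    y_posts.append(model_discrete(**data_obs))

smpl = torch.stack(y_posts)
print(f"mean: {smpl.mean()}")
```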

When I use `parallel` as the enumeration strategy, the expected probability p(y | x_ch, x_pa) is inferred (0.625). However, if I use `sequential`, inference returns a wrong result (0.503). Is there a qualitative difference between the two methods, or is `parallel` only faster? Is there a problem with my code?
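For reference, the expected value can be computed by hand. The sketch below is plain Python (no Pyro) and assumes the model above reduces to y ~ Bernoulli(0.5) with x_ch ~ Normal(y, 1) observed at 1.0; it enumerates y ∈ {0, 1} and normalizes with Bayes' rule. It also computes the Monte Carlo standard error of a mean over 10^4 draws, to judge whether a sampled estimate is consistent with the exact value. The helper names `normal_pdf` and `exact_posterior` are hypothetical.

```python
import math

def normal_pdf(x, mean, std=1.0):
    """Density of Normal(mean, std) evaluated at x."""
    z = (x - mean) / std
    return math.exp(-0.5 * z * z) / (std * math.sqrt(2.0 * math.pi))

def exact_posterior(p_prior, x_ch, noise_std=1.0):
    """P(y = 1 | x_ch) for y ~ Bernoulli(p_prior), x_ch ~ Normal(y, noise_std).

    Enumerates both values of y, forms prior * likelihood, and normalizes.
    """
    joint = {
        y: (p_prior if y == 1 else 1.0 - p_prior) * normal_pdf(x_ch, y, noise_std)
        for y in (0, 1)
    }
    return joint[1] / (joint[0] + joint[1])

posterior = exact_posterior(0.5, 1.0)

# Standard error of the mean of n independent Bernoulli(posterior) draws.
n_draws = 10**4
std_err = math.sqrt(posterior * (1.0 - posterior) / n_draws)

print(f"exact posterior: {posterior:.4f}")              # 0.6225
print(f"std. error for n=10^4 draws: {std_err:.4f}")    # 0.0048
for observed in (0.625, 0.503):
    print(f"{observed}: {(observed - posterior) / std_err:+.1f} standard errors")
```

Under these assumptions the exact value is 1 / (1 + e^(-0.5)) ≈ 0.6225, so 0.625 is within sampling noise, while 0.503 is roughly 25 standard errors away and close to the prior mean of 0.5.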

In your experience, how accurate is discrete inference with Pyro? How would you suggest going about inferring this probability?
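One way to cross-check the result without relying on Pyro's enumeration machinery is self-normalized importance sampling with the prior as the proposal. This is a swapped-in technique used only as a sanity check, not what `infer_discrete` does internally. The sketch assumes the same two-variable model (Bernoulli(0.5) prior on y, unit-variance Normal likelihood for x_ch = 1.0), and the function names are hypothetical. For a model this small, exact enumeration over y ∈ {0, 1} is both cheaper and exact.

```python
import math
import random

def likelihood(x_ch, y, noise_std=1.0):
    # Unnormalized Normal(y, noise_std) density at x_ch; the constant cancels in the ratio below.
    return math.exp(-0.5 * ((x_ch - y) / noise_std) ** 2)

def importance_posterior(p_prior, x_ch, n_samples, seed=0):
    """Self-normalized importance sampling estimate of P(y = 1 | x_ch),
    using the Bernoulli(p_prior) prior as the proposal distribution."""
    rng = random.Random(seed)
    weighted_sum = 0.0
    total_weight = 0.0
    for _ in range(n_samples):
        y = 1 if rng.random() < p_prior else 0  # draw y from the prior
        w = likelihood(x_ch, y)                 # weight by how well y explains x_ch
        weighted_sum += w * y
        total_weight += w
    return weighted_sum / total_weight

estimate = importance_posterior(0.5, 1.0, n_samples=10**5)
exact = 1.0 / (1.0 + math.exp(-0.5))  # closed-form posterior for this particular model
print(f"importance-sampling estimate: {estimate:.4f} (exact {exact:.4f})")
```

Because the proposal is the prior, each weight is just the likelihood, so the estimator is a likelihood-weighted average of the sampled y values.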

Best,
Gunnar

Hi @gcskoenig, the two enumeration strategies should give equivalent results; I’m not sure why you’re seeing different probabilities.

Aside: I’m surprised inference with your model works at all, since you’re not passing `total_count` to the `Binomial` distribution, so during enumeration it is treated as a Bernoulli. You might consider passing in `total_count`.
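The Bernoulli interpretation is easy to check on paper: with n = 1 the binomial pmf C(1, k) p^k (1 − p)^(1 − k) reduces to the Bernoulli pmf. A minimal pure-Python sketch (helper names hypothetical) that checks this for a few values of p:

```python
import math

def binomial_pmf(k, n, p):
    """P(K = k) for K ~ Binomial(n, p)."""
    return math.comb(n, k) * p**k * (1.0 - p) ** (n - k)

def bernoulli_pmf(k, p):
    """P(K = k) for K ~ Bernoulli(p), with k in {0, 1}."""
    return p if k == 1 else 1.0 - p

# With a single trial the two distributions assign identical probabilities.
for p in (0.1, 0.5, 0.9):
    for k in (0, 1):
        assert math.isclose(binomial_pmf(k, 1, p), bernoulli_pmf(k, p))
print("Binomial(total_count=1, p) and Bernoulli(p) have identical pmfs")
```

As far as I know, `torch.distributions.Binomial`, which Pyro's `dist.Binomial` builds on, already defaults to `total_count=1`, so passing it explicitly should not change the model.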

OK, weird! Especially since it is such a basic example.

Regarding the model specification, a Bernoulli is exactly what I was interested in. I added `total_count=1`, which changes neither the result nor the disparity between parallel and sequential enumeration.

I will open an issue on GitHub, then.