Elbo.compute_marginals throws an error for basic model

Hi,

Consider a dummy example:

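```python
import pyro
import pyro.distributions as dist
from pyro import sample
from pyro.infer import config_enumerate

@config_enumerate()
def model(a=None, b=None):
    a = sample("a", dist.Bernoulli(0.75), obs=a)
    p = 1 - 0.25 * a
    b = sample("b", dist.Bernoulli(p), obs=b)
```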

Can anyone help me understand why the two lines below work as expected:

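```python
from torch import tensor as tt  # assumed alias for torch.tensor
def guide(a=None, b=None): pass  # assumed empty guide: every latent site is enumerated
elbo = pyro.infer.TraceEnum_ELBO(max_plate_nesting=1)
margin = elbo.compute_marginals(model, guide, b=tt(1.))
```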

while:

```python
margin = elbo.compute_marginals(model, guide, a=tt(1.))
margin = elbo.compute_marginals(model, guide)
```

both fail with the error:

```
ValueError: Number of einsum subscripts must be equal to the number of operands.
```

How can I get `elbo.compute_marginals(model, guide, a=tt(1.))` running?
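For reference, here is what the marginals should come out to for each of the three calls, computed by brute-force enumeration of the joint in plain Python. This is only a sketch for sanity-checking: the helper names are made up, and it reports P(site = 1) as an exact `Fraction` rather than the per-site `Distribution` objects that `compute_marginals` returns.

```python
from fractions import Fraction
from itertools import product

# Hand-written version of the model above (not Pyro code):
#   a ~ Bernoulli(3/4),  b | a ~ Bernoulli(1 - a/4)
P_A1 = Fraction(3, 4)

def bernoulli_pmf(x, p):
    """Probability of outcome x in {0, 1} under Bernoulli(p)."""
    return p if x == 1 else 1 - p

def joint(a, b):
    """Joint probability p(a, b) of the two-site model."""
    p_b1 = 1 - Fraction(1, 4) * a
    return bernoulli_pmf(a, P_A1) * bernoulli_pmf(b, p_b1)

def marginals(a=None, b=None):
    """Exact P(site = 1 | observations) for every unobserved site.

    Sums the joint over all assignments consistent with the observed values,
    which is what exact enumeration computes in principle.
    """
    consistent = [
        (va, vb)
        for va, vb in product((0, 1), repeat=2)
        if (a is None or va == a) and (b is None or vb == b)
    ]
    evidence = sum(joint(va, vb) for va, vb in consistent)
    result = {}
    if a is None:
        result["a"] = sum(joint(va, vb) for va, vb in consistent if va == 1) / evidence
    if b is None:
        result["b"] = sum(joint(va, vb) for va, vb in consistent if vb == 1) / evidence
    return result

print(marginals(b=1))  # {'a': Fraction(9, 13)}
print(marginals(a=1))  # {'b': Fraction(3, 4)}
print(marginals())     # {'a': Fraction(3, 4), 'b': Fraction(13, 16)}
```

Exact fractions are used so the expected numbers can be compared without floating-point noise; e.g. P(a = 1 | b = 1) = (3/4 · 3/4) / (3/4 · 3/4 + 1/4 · 1) = 9/13.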

Cheers!

Hi @Bart, I believe this error is caused by a bug in compute_marginals related to an unhandled edge case (variables with neither an enumerated ancestor nor an observed descendant).
Can you open a GitHub issue with a runnable snippet and a full traceback so that we can fix it for you?
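One way to see why exactly these calls fail is to check that condition mechanically on the two-site graph a → b. The helper below is hypothetical: it only restates the condition as worded here (an enumerated site with neither an enumerated ancestor nor an observed descendant) and is not how Pyro implements enumeration internally.

```python
# Hypothetical helper restating the edge-case condition on a tiny DAG;
# it does not reproduce Pyro's internal logic.
PARENTS = {"a": [], "b": ["a"]}  # edge a -> b from the model above

def ancestors(node):
    """All ancestors of node, found by walking parent links."""
    seen = set()
    stack = list(PARENTS[node])
    while stack:
        n = stack.pop()
        if n not in seen:
            seen.add(n)
            stack.extend(PARENTS[n])
    return seen

def descendants(node):
    """All sites that have node among their ancestors."""
    return {n for n in PARENTS if node in ancestors(n)}

def edge_case_sites(observed):
    """Enumerated (unobserved) sites with no enumerated ancestor and no observed descendant."""
    observed = set(observed)
    latent = set(PARENTS) - observed
    return sorted(
        n for n in latent
        if not (ancestors(n) & latent) and not (descendants(n) & observed)
    )

print(edge_case_sites({"b"}))  # [] -> matches the b=tt(1.) call that worked
print(edge_case_sites({"a"}))  # ['b'] -> matches the a=tt(1.) call that failed
print(edge_case_sites(set()))  # ['a'] -> matches the call with no observations
```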

Hi,

Thanks for the confirmation. I will double-check and open a GitHub issue.

Hi, is there a convenient way to avoid this error, for example by adding dummy descendants?

Hi @lyy, this should now be fixed on Pyro's dev branch. I'd recommend installing Pyro from dev rather than using a workaround. Thanks for reporting the bug.