TraceEnum_ELBO parallel enumeration fails with error about cartesian products

I have a model that uses a 2-D tensor of Bernoulli values. The first dimension is independent and parallelized, but the model fails with

NotImplementedError: Pyro does not enumerate over cartesian products

Is there a way around this?

Here’s a small code example that repros the error:

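import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate
from pyro.optim import Adam

def model_and_guide():
    with pyro.iarange('test', 2):
        values = pyro.sample('limit', dist.Bernoulli(logits=torch.Tensor([[1,2,3], [4,5,6]])).independent(1))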

pyro.clear_param_store()
optimizer = Adam({"lr": 0.005, "betas": (0.95, 0.999)})
svi = SVI(model_and_guide, config_enumerate(model_and_guide, default='parallel'), optimizer, loss=TraceEnum_ELBO(max_iarange_nesting=1))
svi.step()

Thanks,

Ravi

Great question! There are a few ways to address this.

  1. If your variables are conditionally independent (as in your trivial example), you can replace the .independent(1) with a second iarange:

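    @config_enumerate(default="parallel")
    def model_and_guide():
        logits = torch.tensor([[1.,2.,3.], [4.,5.,6.]])
        # Two nested iaranges: train with TraceEnum_ELBO(max_iarange_nesting=2).
        # 'x_axis' is entered first, so it takes the rightmost batch dim (size 3).
        with pyro.iarange('x_axis', 3), pyro.iarange('y_axis', 2):
            values = pyro.sample('limit', dist.Bernoulli(logits=logits))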
    
  2. If your variables are not conditionally independent (e.g. if a downstream observe statement combines values along one of the iaranges), then you have two options.
    The cleaner option is to convert one parallel iarange into a sequential irange:

    @config_enumerate(default="parallel")
    def model_and_guide():
        logits = torch.tensor([[1.,2.,3.], [4.,5.,6.]])
        with pyro.iarange('y_axis', 2):
            values = []
            for i in pyro.irange('x_axis', 3):
                value = pyro.sample('limit_{}'.format(i),
                                    dist.Bernoulli(logits=logits[:, i]))
                values.append(value)
            # This observe statement forces us to sequentially sample along x_axis:
            pyro.sample('aggregate', dist.Binomial(3, 0.1), obs=sum(values))
    

    The second option is to manually compute the cartesian product by replacing the Bernoulli with a Categorical over the product space; this is kind of gross.
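For the second option, the core step is building the product space. Below is a minimal stdlib-only sketch, assuming the Bernoullis within one row are independent a priori (the dependence comes only from downstream factors): list all 2^n joint assignments and compute each one's log-probability. A Categorical with these logits (in Pyro, something like `dist.Categorical(logits=torch.tensor(log_probs))`) then draws an index k that decodes back to `states[k]`. The helpers `log_sigmoid` and `product_space_logits` are hypothetical names for illustration, not Pyro API.

```python
import itertools
import math


def log_sigmoid(x):
    """Numerically stable log(sigmoid(x))."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def product_space_logits(bernoulli_logits):
    """Hypothetical helper: joint log-probs of independent Bernoullis.

    Returns (states, log_probs), where states[k] is a tuple of 0/1 values
    and log_probs[k] = log p(states[k]). A Categorical over len(states)
    outcomes with these logits has the same joint distribution.
    """
    states = list(itertools.product((0, 1), repeat=len(bernoulli_logits)))
    log_probs = []
    for state in states:
        lp = 0.0
        for bit, logit in zip(state, bernoulli_logits):
            # log p(bit=1) = log_sigmoid(logit); log p(bit=0) = log_sigmoid(-logit)
            lp += log_sigmoid(logit) if bit else log_sigmoid(-logit)
        log_probs.append(lp)
    return states, log_probs


# One row of the example: 3 Bernoullis become one 8-way Categorical.
states, log_probs = product_space_logits([1.0, 2.0, 3.0])
print(len(states), states[0], states[-1])  # 8 (0, 0, 0) (1, 1, 1)
```

In the 2×3 example each row of the iarange would get its own 8-way Categorical (logits of shape (2, 8)); the 2^n growth in the number of categories is what makes this approach unwieldy.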


Thanks for the quick response! I tried option 2 since my variables are dependent, and something very strange is going on.

def model_and_guide():
    logits = torch.Tensor([[1,2,3], [4,5,6]])
    with pyro.iarange('test', 2):
        values = []
        for i in pyro.irange('test2', 3):
            s = pyro.sample('limit_%d' % i, dist.Bernoulli(logits=logits[:,i]))
            print(logits[:, i].shape, s.shape)
            values.append(s)
        values = torch.stack(values, dim=1)

Each iteration inserts a new dimension. The output is:

torch.Size([2]) torch.Size([2, 2])
torch.Size([2]) torch.Size([2, 1, 2])
torch.Size([2]) torch.Size([2, 1, 1, 2])

Yes, that is to be expected: each new variable needs to be enumerated in a different dimension. Broadcast together, they form a cartesian product (whose volume grows exponentially with the number of variables, as expected). Often you can work with these differently shaped tensors efficiently without broadcasting them together. I'd recommend reading (or re-reading) the Tensor Shapes Tutorial to understand how Pyro's enumeration works.
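To make the shape bookkeeping concrete, here is a stdlib-only sketch that reproduces the three printed shapes and shows what broadcasting them together costs. It assumes, as the output above suggests, that with `max_iarange_nesting=1` the k-th enumerated site lands at dim -(max_iarange_nesting + 1 + k). `enum_shape` and `broadcast_shape` are hypothetical helpers that mimic this; they are not Pyro functions.

```python
from math import prod


def enum_shape(k, support_size, batch_shape, max_iarange_nesting=1):
    """Hypothetical helper: shape of the k-th parallel-enumerated sample.

    Assumes, as the printed shapes above suggest, that the k-th enumerated
    site is placed at dim -(max_iarange_nesting + 1 + k), with singleton
    dims between it and the iarange (batch) dims on the right.
    """
    assert len(batch_shape) <= max_iarange_nesting
    # Left-pad the batch shape so it occupies exactly max_iarange_nesting dims.
    batch = (1,) * (max_iarange_nesting - len(batch_shape)) + tuple(batch_shape)
    return (support_size,) + (1,) * k + batch


def broadcast_shape(*shapes):
    """NumPy/PyTorch-style broadcasting of several shapes."""
    ndim = max(len(s) for s in shapes)
    padded = [(1,) * (ndim - len(s)) + tuple(s) for s in shapes]
    out = []
    for sizes in zip(*padded):
        non_one = {d for d in sizes if d != 1}
        if len(non_one) > 1:
            raise ValueError("shapes are not broadcastable: %r" % (shapes,))
        out.append(non_one.pop() if non_one else 1)
    return tuple(out)


# Three Bernoulli sites (support size 2) inside an iarange of size 2.
shapes = [enum_shape(k, 2, (2,)) for k in range(3)]
print(shapes)  # [(2, 2), (2, 1, 2), (2, 1, 1, 2)]

joint = broadcast_shape(*shapes)
print(joint, prod(joint))  # (2, 2, 2, 2) 16
print(sum(prod(s) for s in shapes))  # 12 entries if kept separate

# With ten sites the gap is exponential: 2**10 * 2 entries vs. 10 * 4.
ten = [enum_shape(k, 2, (2,)) for k in range(10)]
print(prod(broadcast_shape(*ten)), sum(prod(s) for s in ten))  # 2048 40
```

The rightmost dimension of size 2 is the `test` iarange, which all three tensors share rather than cross, so the 16 entries are 2 rows × 2^3 joint states per row.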

I’ve read the Tensor Shapes Tutorial several times already; I’ll go through it one more time :-). I’m still trying to develop a good mental model of what Pyro is doing. Is there a good paper that walks through the basic algorithms?

In this case, isn’t the point of declaring independence to avoid doing the cartesian product? That is, you only need two parallel searches over {0,1}^3 instead of one search over the cartesian product {0,1}^3 × {0,1}^3.
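This intuition can be checked by brute force. The plain-Python sketch below (no Pyro) uses the logits from the example and a per-row factor modeled on the `aggregate` site from option 2, i.e. the Binomial(3, 0.1) likelihood of each row's sum. Because the joint factorizes across independent rows, summing out each row separately (2 sums of 8 terms) gives the same evidence as one sum over the 64-state cartesian product. All helper names are illustrative, not Pyro API.

```python
import itertools
import math

# Rows play the role of the iarange; columns are the dependent variables.
LOGITS = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def bernoulli_prob(bits, logits):
    """Joint probability of independent Bernoulli bits under the given logits."""
    p = 1.0
    for b, logit in zip(bits, logits):
        q = sigmoid(logit)
        p *= q if b else 1.0 - q
    return p


def aggregate_factor(s, n=3, prob=0.1):
    """Binomial(n, prob) likelihood of observing s, like the 'aggregate' site."""
    return math.comb(n, s) * prob ** s * (1 - prob) ** (n - s)


def row_evidence(logits):
    """Sum out the 2**3 latent states of a single row."""
    return sum(bernoulli_prob(x, logits) * aggregate_factor(sum(x))
               for x in itertools.product((0, 1), repeat=len(logits)))


# Exploiting independence across rows: 2 separate sums of 8 terms each.
factorized = math.prod(row_evidence(row) for row in LOGITS)

# Ignoring it: one sum over the cartesian product, 8 * 8 = 64 terms.
joint = 0.0
for x0, x1 in itertools.product(itertools.product((0, 1), repeat=3), repeat=2):
    joint += (bernoulli_prob(x0, LOGITS[0]) * aggregate_factor(sum(x0))
              * bernoulli_prob(x1, LOGITS[1]) * aggregate_factor(sum(x1)))

print(factorized, joint)  # equal up to floating-point rounding
```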

> Is there a good paper

We’re working on a tutorial and tech report about enumeration. Let us know what you’d like to see in those :-)

> isn’t the point of declaring independence to avoid doing the cartesian product?

Yes. I’m happy to explain if you post a more complete example, ideally a full model, guide, and training loop. There are different tricks to avoid the cartesian product, and to recommend one we’ll need to see how you consume values downstream of the sample statements.