Categorical and enumeration

I’m trying to model a system where a Categorical determines which pool an item goes into, and then it produces a random number of observations. The 0.3 tutorials appear to show the enumerated values taking a Cartesian product with the observations, but this doesn’t seem to extend to the case where each choice produces multiple data values. Here is a simplified model:

import torch
from torch import Tensor as TT
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate

simple_data = TT([[1., 3.], [10., 3.], [1., 8.], [9., 4.], [7., 1.]]).transpose(0, 1)
n_pools = simple_data.shape[0]
n_item = simple_data.shape[1]

@config_enumerate
def test_model(data):
    pool_size = pyro.param('pool_size', dist.LogNormal(TT([3]), TT([1])).sample())
    with pyro.plate('item', n_item):
        pool = pyro.sample('pool', dist.Categorical(torch.ones(n_pools)))
        rates = (torch.eye(n_pools)[pool] * 0.5 + 0.1) * pool_size
        if data is not None:
            obs = pyro.sample('obs', dist.Poisson(rates), obs=data)

def test_guide(data):
    test_model(None)

pyro.clear_param_store()
optim = pyro.optim.Adam({'lr': 0.1, 'betas': [0.8, 0.99]})
elbo = TraceEnum_ELBO(max_plate_nesting=0)
svi = SVI(test_model, test_guide, optim, loss=elbo)
svi.step(simple_data)

This produces the following error:

ValueError: Shape mismatch inside plate('item') at site obs dim -1, 5 vs 2
Trace Shapes:    
 Param Sites:    
    pool_size   1
Sample Sites:    
    item dist   |
        value 5 |
    pool dist 5 |
        value 2 |

I’ve tried many variations and have not been able to find something that works. Thanks for your help.

Ravi

Hi @rpandya, could you explain a little more what your model is intended to do? For example, when you say “random number of observations”, do you mean that a random number of observations are measured (each producing an observed value), or that a single measurement is made with a single observed value that is a random number? I don’t fully understand the model, but here are some superficial comments:

  1. You have one plate, so set max_plate_nesting=1.

  2. The shape error is at your pyro.sample('obs', ...) site: under enumeration rates has shape (2, 2), the plate has size 5, and data has shape (2, 5), so the sizes at dim -1 don’t line up. Setting max_plate_nesting=1 will partially fix this, but I think you’ll also need to slice into data using something like obs=data[pool] (see the shape sketch after this list).

  3. The test_guide doesn’t make sense to me. If I understand correctly, you are trying to learn a pool_size parameter under a prior LogNormal(3., 1.). If this is the case, then I would instead suggest using a sample statement in the model and an autoguide:

    def test_model():
        pool_size = pyro.sample("pool_size", dist.LogNormal(3., 1.))
        ...
    
    # (AutoDiagonalNormal is in pyro.infer.autoguide; poutine is pyro.poutine)
    test_guide = AutoDiagonalNormal(poutine.block(test_model, expose=["pool_size"]))
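
To see what enumeration does to shapes, here is a minimal standalone sketch (poutine.enum is the handler TraceEnum_ELBO uses internally; with max_plate_nesting=1 dim -1 is reserved for your plate, so enumeration lands at dim -2):

    import torch
    import pyro
    import pyro.distributions as dist
    from pyro import poutine
    from pyro.infer import config_enumerate

    @config_enumerate
    def demo():
        with pyro.plate('item', 5):
            pool = pyro.sample('pool', dist.Categorical(torch.ones(2)))
            print('pool', pool.shape)

    demo()                                        # pool torch.Size([5])
    poutine.enum(demo, first_available_dim=-2)()  # pool torch.Size([2, 1])

Any tensor you compute from pool, like your rates, silently picks up that extra enumeration dim on the left, and your obs must broadcast against it.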
    

Thanks @fritzo! Sorry for the lack of clarity: I meant a single measurement of the pool values for each item, where the distribution in each pool depends on the categorical for that item. Your other comments were helpful and got me a bit farther, but the dimensionality of obs still didn’t match:

from pyro import poutine
from pyro.infer.autoguide import AutoDiagonalNormal  # pyro.contrib.autoguide in older Pyro

simple_data = TT([[1., 3.], [10., 3.], [1., 8.], [9., 4.], [7., 1.]])
n_pools = simple_data.shape[1]
n_item = simple_data.shape[0]

@config_enumerate
def test_model(data):
    pool_size = pyro.sample('pool_size', dist.LogNormal(3., 1.))
    print('pool_size', pool_size)
    with pyro.plate('item', n_item) as items:
        print('items', items)
        pool = pyro.sample('pool', dist.Categorical(torch.ones(n_pools)))
        print('pool', pool, pool.shape)
        rates = (torch.eye(n_pools)[pool] * 0.5 + 0.1) * pool_size
        print('rates', rates, rates.shape)
        if len(pool.shape) > 1: # hack to skip obs in guide
            print('data[pool]', data[pool], data[pool].shape)
            obs = pyro.sample('obs', dist.Poisson(rates), obs=data[pool])
            print('obs', obs, obs.shape)

test_guide = AutoDiagonalNormal(poutine.block(test_model, expose=['pool_size']))

pyro.clear_param_store()
optim = pyro.optim.Adam({'lr': 0.1, 'betas': [0.8, 0.99]})
elbo = TraceEnum_ELBO(max_plate_nesting=1)
svi = SVI(test_model, test_guide, optim, loss=elbo)
svi.step(simple_data)

# output

('pool_size', tensor(58.5516))
('items', tensor([0, 1, 2, 3, 4]))
('pool', tensor([0, 1, 0, 1, 1]), torch.Size([5]))
('rates', tensor([[35.1310,  5.8552],
        [ 5.8552, 35.1310],
        [35.1310,  5.8552],
        [ 5.8552, 35.1310],
        [ 5.8552, 35.1310]]), torch.Size([5, 2]))
('pool_size', tensor(0.6469, grad_fn=<ExpandBackward>))
('items', tensor([0, 1, 2, 3, 4]))
('pool', tensor([[0],
        [1]]), torch.Size([2, 1]))
('rates', tensor([[[0.3881, 0.0647]],

        [[0.0647, 0.3881]]], grad_fn=<MulBackward0>), torch.Size([2, 1, 2]))
# >> this is not right >>
('data[pool]', tensor([[[ 1.,  3.]],

        [[10.,  3.]]]), torch.Size([2, 1, 2]))

So I tried enumerating the items sequentially instead, and this worked:

@config_enumerate
def test_model(data):
    pool_size = pyro.sample('pool_size', dist.LogNormal(3., 1.))
    print('pool_size', pool_size)
    for i in pyro.plate('item', n_item):
        pool = pyro.sample('pool_%i' % i, dist.Categorical(torch.ones(n_pools)))
        print('pool', pool, pool.shape)
        rates = (torch.eye(n_pools)[pool] * 0.5 + 0.1) * pool_size
        print('rates', rates, rates.shape)
        if len(pool.shape) > 1: # hack to skip obs in guide
            print('data[pool]', data[pool], data[pool].shape)
            obs = pyro.sample('obs_%d' % i, dist.Poisson(rates).to_event(1), obs=data[i])
            print('obs', obs, obs.shape)

test_guide = AutoDiagonalNormal(poutine.block(test_model, expose=['pool_size'], hide=['obs']))

pyro.clear_param_store()
optim = pyro.optim.Adam({'lr': 0.1, 'betas': [0.8, 0.99]})
elbo = TraceEnum_ELBO(max_plate_nesting=1)
svi = SVI(test_model, test_guide, optim, loss=elbo)
svi.step(simple_data)

This didn’t give any errors, but there were some strange things: the pool shape kept adding dimensions, though IIUC iterating over the plate should not add dimensions, since the items are conditionally independent?

('pool', tensor([[0],
        [1]]), torch.Size([2, 1]))
# keeps getting deeper with each iteration
('pool', tensor([[[[[[0]]]]],
        [[[[[1]]]]]]), torch.Size([2, 1, 1, 1, 1, 1]))

I tried running a hundred steps to optimize the loss, then used the code from the Gaussian mixture example that uses infer_discrete to pull out the pool assignments, but they were random AFAICT - what drives the choice of optimal assignments there?
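
i.e. something along these lines (a sketch adapted from that example, using my sequential model’s site names):

from pyro import poutine
from pyro.infer import infer_discrete

# freeze the trained continuous latents, then sample the discrete sites
guide_trace = poutine.trace(test_guide).get_trace(simple_data)
trained_model = poutine.replay(test_model, trace=guide_trace)
inferred_model = infer_discrete(trained_model, temperature=1,
                                first_available_dim=-2)
trace = poutine.trace(inferred_model).get_trace(simple_data)
print(trace.nodes['pool_0']['value'])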

Thanks for your help, though - this is progress!

Ravi

That’s simply because Pyro’s enumeration logic is not yet smart enough to understand plate iterators. In your model you can instead use pyro.markov(history=0), which should lead to the same enumeration dim being used each iteration:

@config_enumerate
def test_model(data):
    pool_size = pyro.sample('pool_size', dist.LogNormal(3., 1.))
    for i in pyro.markov(range(n_item), history=0):
        ...
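
For completeness, the loop body spelled out (a sketch based on your sequential model above, untested; with the autoguide exposing only pool_size, the shape-check hack shouldn’t be needed):

@config_enumerate
def test_model(data):
    pool_size = pyro.sample('pool_size', dist.LogNormal(3., 1.))
    # history=0 marks iterations as independent, so Pyro can recycle
    # one enumeration dim instead of allocating a new one per step
    for i in pyro.markov(range(n_item), history=0):
        pool = pyro.sample('pool_%d' % i, dist.Categorical(torch.ones(n_pools)))
        rates = (torch.eye(n_pools)[pool] * 0.5 + 0.1) * pool_size
        pyro.sample('obs_%d' % i, dist.Poisson(rates).to_event(1), obs=data[i])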

FYI, I did finally get the parallel version working. For the record in case someone else has a similar problem, the (totally non-obvious!) trick was to replicate the observed data for each enumeration and add a dimension for the plate, i.e.

data_obs = TT([data.numpy()])[pool * 0]
print('data_obs', data_obs, data_obs.shape)
obs = pyro.sample('obs', dist.Poisson(rates).to_event(1), obs=data_obs)
print('obs', obs, obs.shape)

('data_obs', tensor([[[[ 1.,  3.],
          [10.,  3.],
          [ 1.,  8.],
          [ 9.,  4.],
          [ 7.,  1.]]],
        [[[ 1.,  3.],
          [10.,  3.],
          [ 1.,  8.],
          [ 9.,  4.],
          [ 7.,  1.]]]]), torch.Size([2, 1, 5, 2]))
('obs', tensor([[[[ 1.,  3.],
          [10.,  3.],
          [ 1.,  8.],
          [ 9.,  4.],
          [ 7.,  1.]]],
        [[[ 1.,  3.],
          [10.,  3.],
          [ 1.,  8.],
          [ 9.,  4.],
          [ 7.,  1.]]]]), torch.Size([2, 1, 5, 2]))
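
For reference, here is the whole parallel model with that trick in place (a consolidated sketch assembled from the pieces above, with TraceEnum_ELBO(max_plate_nesting=1) and the AutoDiagonalNormal guide from before):

@config_enumerate
def test_model(data):
    pool_size = pyro.sample('pool_size', dist.LogNormal(3., 1.))
    with pyro.plate('item', n_item):
        pool = pyro.sample('pool', dist.Categorical(torch.ones(n_pools)))
        rates = (torch.eye(n_pools)[pool] * 0.5 + 0.1) * pool_size
        if len(pool.shape) > 1:  # skip obs when pool isn't enumerated
            # pool is (2, 1) under enumeration, so pool * 0 indexes the
            # leading singleton of the (1, 5, 2) tensor and replicates the
            # data to (2, 1, 5, 2): enum dims on the left, item plate at
            # batch dim -1, and the trailing 2 as the event dim
            data_obs = TT([data.numpy()])[pool * 0]
            pyro.sample('obs', dist.Poisson(rates).to_event(1), obs=data_obs)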

One more question: Is there a way to get likelihoods or posteriors out of infer_discrete? I’m getting some results where each run randomly chooses one of several equally good discrete options, which I’d like to discount relative to cases where there is a clear best choice.

Thanks,

Ravi