Indexing with enumerated variables

Hi, I have a simple model where I sample a Gaussian variable x with shape [3, 30, 100] and a Categorical variable q with shape [100] and values in {0, 1, 2}. I want to use q to index x: for each of the 100 elements along dim -1 of x, the element along dim -3 is selected according to the categorical value in q, resulting in a tensor y of shape [30, 100]. Without enumeration, I can achieve this with x[q, :, torch.arange(100)].T. But I don’t understand how this would work with enumeration, where q has shape [3, 1, 1, 1]. Any advice?

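import torch
import pyro
import pyro.distributions as dist
from pyro.infer import TraceEnum_ELBO, config_enumerate

@config_enumerate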
def model():
    plate_1 = pyro.plate('plate_1', 100, dim=-1)
    plate_2 = pyro.plate('plate_2', 30, dim=-2)
    plate_3 = pyro.plate('plate_3', 3, dim=-3)

    with plate_3, plate_2, plate_1:
        x = pyro.sample('x', dist.Normal(0, 1))

    with plate_1:
        q = pyro.sample('q', dist.Categorical(torch.ones(3)))

    # works without enumeration
    y = x[q, :, torch.arange(x.shape[-1])].T

def guide():
    pass

elbo = TraceEnum_ELBO(max_plate_nesting=3)
elbo.loss(model, guide)

You cannot do that with enumeration, because the point of enumeration is that you don’t sample a specific value of q but instead enumerate all of its possible values (i.e. {0, 1, 2}). TraceEnum_ELBO then uses these enumerated values to marginalize (sum) over q in the model log-density.
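
To make the shapes concrete, here is a minimal sketch using plain torch.distributions rather than Pyro itself. The reshape to [3, 1, 1, 1] is written out by hand to mimic what the thread reports for max_plate_nesting=3; it is an illustration of the idea, not Pyro's internal code.

```python
import torch
from torch.distributions import Categorical

d = Categorical(probs=torch.ones(3))  # probs are normalized, so uniform over {0, 1, 2}

# Every value q can take, laid out along one leading dimension.
values = d.enumerate_support()
assert values.tolist() == [0, 1, 2]

# Assumption: with three plate dims reserved, the enumeration dim sits to
# their left, which gives the [3, 1, 1, 1] shape reported in the thread.
q_enum = values.reshape(-1, 1, 1, 1)
assert q_enum.shape == (3, 1, 1, 1)

# Marginalizing means summing over that leading dim (in log space: logsumexp).
# Here only the prior is summed, so the result is log(1) = 0.
log_prob = d.log_prob(q_enum)
total = log_prob.logsumexp(dim=0)
assert torch.allclose(total, torch.zeros(1, 1, 1), atol=1e-6)
```

Anything computed downstream from q therefore has to broadcast against that extra leading dimension instead of assuming q has shape [100].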

You can use Vindex (from pyro.ops.indexing) for the indexing, and you should end up with shape [3, 1, 30, 100] for y.

Thanks for your answer. I am aware that the shape of y won’t be [30, 100] when doing enumeration. Could you provide an example of how to use Vindex in this case?

If I do:

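import torch
import pyro
import pyro.distributions as dist
from pyro.infer import TraceEnum_ELBO, config_enumerate
from pyro.ops.indexing import Vindex

@config_enumerate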
def model():
    plate_1 = pyro.plate('plate_1', 100, dim=-1)
    plate_2 = pyro.plate('plate_2', 30, dim=-2)
    plate_3 = pyro.plate('plate_3', 3, dim=-3)

    with plate_3, plate_2, plate_1:
        x = pyro.sample('x', dist.Normal(0, 1))

    with plate_1:
        q = pyro.sample('q', dist.Categorical(torch.ones(3)))

    y = Vindex(x)[q, :, torch.arange(x.shape[-1])].transpose(-1, -2)

    print('x', x.shape)
    print('q', q.shape)
    print('y', y.shape)

def guide():
    pass

print('sampling without enumeration:')
model()

print('\nsampling with enumeration:')
elbo = TraceEnum_ELBO(max_plate_nesting=3)
elbo.loss(model, guide)

I get

sampling without enumeration:
x torch.Size([3, 30, 100])
q torch.Size([100])
y torch.Size([30, 100])

sampling with enumeration:
x torch.Size([3, 30, 100])
q torch.Size([3, 1, 1, 1])
y torch.Size([3, 1, 1, 30, 100])

So would I have to do something like Vindex(x)[q.squeeze(-1), :, torch.arange(x.shape[-1])] to achieve shape [3, 1, 30, 100] for y? That doesn’t look right to me somehow.

You can index it like this:

    with plate_1:
        q = pyro.sample('q', dist.Categorical(torch.ones(3)))

    with plate_3, plate_2 as jdx, plate_1 as idx:
        x = pyro.sample('x', dist.Normal(0, 1))
        y = Vindex(x)[q, jdx[:, None], idx]
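
As a shape check for this first variant, here is a small sketch in plain PyTorch. It assumes that jdx and idx are simply torch.arange(30) and torch.arange(100) (no subsampling), and it relies on the fact that, when every index is a tensor, the expression reduces to ordinary broadcasting advanced indexing, as far as I can tell. The tensors x, q_sampled and q_enum are stand-ins created by hand, not values drawn inside a Pyro model.

```python
import torch

torch.manual_seed(0)
x = torch.randn(3, 30, 100)  # same layout as x in the model
jdx = torch.arange(30)       # assumed value of `plate_2 as jdx`
idx = torch.arange(100)      # assumed value of `plate_1 as idx`

# Case 1: q drawn as ordinary samples, shape [100].
q_sampled = torch.randint(0, 3, (100,))
y_sampled = x[q_sampled, jdx[:, None], idx]
assert y_sampled.shape == (30, 100)
# Same result as the original non-enumerated expression.
assert torch.equal(y_sampled, x[q_sampled, :, idx].T)

# Case 2: q enumerated, shape [3, 1, 1, 1].
q_enum = torch.arange(3).reshape(3, 1, 1, 1)
y_enum = x[q_enum, jdx[:, None], idx]
assert y_enum.shape == (3, 1, 30, 100)
# Entry k along the enumeration dim picks component k for every (j, i).
assert torch.equal(y_enum[:, 0], x)
```

The index shapes [3, 1, 1, 1], [30, 1] and [100] broadcast to [3, 1, 30, 100], which is why no squeeze on q is needed.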

Or, a better approach in my opinion, is to turn plate_3 into an event dim by using to_event, so that it sits on the right-hand side. This makes the Vindex-ing easier:

    with plate_1:
        q = pyro.sample('q', dist.Categorical(torch.ones(3)))

    with plate_2, plate_1:
        x = pyro.sample('x', dist.Normal(0, 1).expand([3]).to_event(1))
        y = Vindex(x)[..., q]

...
elbo = TraceEnum_ELBO(max_plate_nesting=2)