# Indexing with enumerated variables

Hi, I have a simple model where I sample a Gaussian variable `x` with shape [3, 30, 100] and a Categorical variable `q` with shape [100] and values in {0, 1, 2}. I want to use `q` to index `x` such that, for each of the 100 elements in the -1 dimension of `x`, the element in the -3 dimension is selected according to the corresponding value in `q`, resulting in a tensor `y` of shape [30, 100]. Without enumeration, I can achieve this with `x[q, :, torch.arange(100)].T`. But I don’t understand how this would work with enumeration, when `q` has shape [3, 1, 1, 1]. Any advice?
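
Why the `.T` is needed in the non-enumerated case can be seen with plain PyTorch advanced indexing. The sketch below uses random tensors as hypothetical stand-ins for the sampled `x` and `q` (no Pyro involved). Because the two tensor indices are separated by a slice, PyTorch places their broadcast dimension (100) first and appends the sliced dimension (30) after it.

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-ins for the sampled sites (no Pyro involved).
x = torch.randn(3, 30, 100)        # Gaussian site, shape [3, 30, 100]
q = torch.randint(0, 3, (100,))    # one category per plate_1 element

# Tensor indices on dims 0 and 2 are separated by ':', so the broadcast
# index shape [100] comes first and the sliced dim [30] comes last.
raw = x[q, :, torch.arange(100)]
print(raw.shape)  # torch.Size([100, 30])

y = raw.T         # transpose to the desired [30, 100]
print(y.shape)    # torch.Size([30, 100])

# Column i of y is the row of x selected by q[i] at position i.
i = 7
assert torch.equal(y[:, i], x[int(q[i]), :, i])
```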

``````
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import TraceEnum_ELBO, config_enumerate

@config_enumerate
def model():
    plate_1 = pyro.plate('plate_1', 100, dim=-1)
    plate_2 = pyro.plate('plate_2', 30, dim=-2)
    plate_3 = pyro.plate('plate_3', 3, dim=-3)

    with plate_3, plate_2, plate_1:
        x = pyro.sample('x', dist.Normal(0, 1))

    with plate_1:
        q = pyro.sample('q', dist.Categorical(torch.ones(3)))

    # works without enumeration
    y = x[q, :, torch.arange(x.shape[-1])].T

def guide():
    pass

elbo = TraceEnum_ELBO(max_plate_nesting=3)
elbo.loss(model, guide)
``````

You cannot do that with enumeration, because the point of enumeration is that you don’t sample a specific value of `q` but enumerate all of its possible values (i.e. {0, 1, 2}) at once. `TraceEnum_ELBO` then uses these enumerated values to marginalize (sum) over `q` in the model log-density.
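
Conceptually, "summing over `q`" means the model's log-density terms are computed for every candidate value in parallel along an extra enumeration dimension, and those terms are then combined with a log-sum-exp. The sketch below does this by hand in plain PyTorch as an illustration of the idea, not Pyro's actual implementation. The observation `obs` and the likelihood `Normal(q, 1)` are hypothetical (the model in this thread has no observed site); they only exist so there is something to marginalize.

```python
import torch

torch.manual_seed(0)

# All values of q at once, in the [3, 1, 1, 1] layout shown in this thread
# for max_plate_nesting=3 (the enumeration dim is -4).
q_enum = torch.arange(3).reshape(3, 1, 1, 1)

# Uniform prior p(q = k) = 1/3, laid out like the enumerated values.
log_prior = torch.full(q_enum.shape, 1.0 / 3).log()

# Hypothetical observation and likelihood p(obs | q) = Normal(q, 1).
obs = torch.randn(100)
log_lik = torch.distributions.Normal(q_enum.float(), 1.0).log_prob(obs)  # [3, 1, 1, 100]

# Marginalize q: log sum_k p(q = k) * p(obs | q = k), along the enum dim.
log_marginal = torch.logsumexp(log_prior + log_lik, dim=-4)
print(log_marginal.shape)  # torch.Size([1, 1, 100])
```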

You can use `Vindex` (from `pyro.ops.indexing`) for indexing, and you should end up with shape [3, 1, 30, 100] for `y`.
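
Where [3, 1, 30, 100] comes from can be checked with broadcasting alone: when every dimension of `x` is indexed by a tensor, the result has the broadcast shape of the index tensors. The shapes below assume the enumerated `q` ([3, 1, 1, 1]), a `plate_2` index reshaped to [30, 1], and a `plate_1` index of shape [100]; dim -3 ends up as 1 because `plate_3` has been indexed away by `q`.

```python
import torch

q_shape = (3, 1, 1, 1)  # enumerated q: enum dim -4, then plate dims -3, -2, -1
j_shape = (30, 1)       # plate_2 positions, aligned with dim -2
i_shape = (100,)        # plate_1 positions, aligned with dim -1

y_shape = torch.broadcast_shapes(q_shape, j_shape, i_shape)
print(y_shape)  # torch.Size([3, 1, 30, 100])
```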

Thanks for your answer. I am aware that the shape of `y` won’t be `[30, 100]` when doing enumeration. Could you provide an example of how to use `Vindex` in this case?

If I do:

``````
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import TraceEnum_ELBO, config_enumerate
from pyro.ops.indexing import Vindex

@config_enumerate
def model():
    plate_1 = pyro.plate('plate_1', 100, dim=-1)
    plate_2 = pyro.plate('plate_2', 30, dim=-2)
    plate_3 = pyro.plate('plate_3', 3, dim=-3)

    with plate_3, plate_2, plate_1:
        x = pyro.sample('x', dist.Normal(0, 1))

    with plate_1:
        q = pyro.sample('q', dist.Categorical(torch.ones(3)))

    y = Vindex(x)[q, :, torch.arange(x.shape[-1])].transpose(-1, -2)

    print('x', x.shape)
    print('q', q.shape)
    print('y', y.shape)

def guide():
    pass

print('sampling without enumeration:')
model()

print('\nsampling with enumeration:')
elbo = TraceEnum_ELBO(max_plate_nesting=3)
elbo.loss(model, guide)
``````

I get

``````
sampling without enumeration:
x torch.Size([3, 30, 100])
q torch.Size([100])
y torch.Size([30, 100])

sampling with enumeration:
x torch.Size([3, 30, 100])
q torch.Size([3, 1, 1, 1])
y torch.Size([3, 1, 1, 30, 100])
``````

So would I have to do something like `Vindex(x)[q.squeeze(-1), :, torch.arange(x.shape[-1])]` to achieve shape `[3, 1, 30, 100]` for `y`? Somehow that doesn’t look right to me.
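
One way to see where the extra singleton dimension comes from is to run the same expression with plain PyTorch advanced indexing (no `Vindex`) on a hypothetical stand-in for `x`, with `q` in its enumerated [3, 1, 1, 1] layout as printed above. The `:` separates the two tensor indices, so their broadcast shape [3, 1, 1, 100] goes first and the sliced dimension of size 30 is appended. Nothing is aligned with `plate_2` at dim -2, so the size-1 dims already present in `q` are all kept.

```python
import torch

x = torch.randn(3, 30, 100)              # hypothetical stand-in for x
q = torch.arange(3).reshape(3, 1, 1, 1)  # enumerated q, as printed above
i = torch.arange(x.shape[-1])            # plate_1 positions, shape [100]

# Tensor indices separated by ':' -> broadcast shape [3, 1, 1, 100] first,
# then the sliced dim of size 30 -> [3, 1, 1, 100, 30], then transposed.
y = x[q, :, i].transpose(-1, -2)
print(y.shape)     # torch.Size([3, 1, 1, 30, 100])

# Dropping q's trailing singleton removes the extra dim in this layout.
y_sq = x[q.squeeze(-1), :, i].transpose(-1, -2)
print(y_sq.shape)  # torch.Size([3, 1, 30, 100])

# Each enumerated value k selects the whole slice x[k].
assert torch.equal(y_sq[2, 0], x[2])
```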

You can index it like this:

``````
    with plate_1:
        q = pyro.sample('q', dist.Categorical(torch.ones(3)))

    with plate_3, plate_2 as jdx, plate_1 as idx:
        x = pyro.sample('x', dist.Normal(0, 1))
        y = Vindex(x)[q, jdx[:, None], idx]
``````
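
Because every dimension of `x` gets a tensor index here, the result is simply the broadcast of the three index shapes, so the same expression can be checked with plain PyTorch advanced indexing. The sketch below substitutes `torch.arange` for the index tensors that `pyro.plate(...)` yields when used with `as` (assuming no subsampling, so they cover the full range) and uses random stand-ins for the sampled sites. It checks both the enumerated and the non-enumerated case.

```python
import torch

torch.manual_seed(0)
x = torch.randn(3, 30, 100)  # stand-in for the sample at site 'x'
jdx = torch.arange(30)       # assumed value of `plate_2 as jdx` (no subsampling)
idx = torch.arange(100)      # assumed value of `plate_1 as idx`

# Enumerated case: q holds all of {0, 1, 2} along dim -4.
q_enum = torch.arange(3).reshape(3, 1, 1, 1)
y_enum = x[q_enum, jdx[:, None], idx]
print(y_enum.shape)  # torch.Size([3, 1, 30, 100])

# Non-enumerated case: one sampled category per plate_1 element.
q = torch.randint(0, 3, (100,))
y = x[q, jdx[:, None], idx]
print(y.shape)       # torch.Size([30, 100]), no transpose needed

# y[j, i] == x[q[i], j, i]
assert torch.equal(y[:, 5], x[int(q[5]), :, 5])
```

Indexing `plate_2` with `jdx[:, None]` is what keeps dim -2 aligned with the 30 rows, which is why no transpose is needed in either case.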

Or, a better approach in my opinion is to make `plate_3` an event dim by using `to_event`, so that it sits on the right-hand side of the shape. This makes indexing with `Vindex` easier:

``````
    with plate_1:
        q = pyro.sample('q', dist.Categorical(torch.ones(3)))

    with plate_2, plate_1:
        # expand() needs the event size; x then has shape [30, 100, 3]
        x = pyro.sample('x', dist.Normal(0, 1).expand([3]).to_event(1))
        y = Vindex(x)[..., q]

...
elbo = TraceEnum_ELBO(max_plate_nesting=2)
``````
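
With the event-dim layout, `x` has shape [30, 100, 3]: the batch dims come from `plate_2` and `plate_1`, followed by the event dim of size 3. With `max_plate_nesting=2`, the enumerated `q` gets shape [3, 1, 1]. As I understand it, `Vindex(x)[..., q]` selects along the rightmost dim while broadcasting `q` against the batch dims. The plain-PyTorch sketch below spells that out with explicit index tensors; these are hypothetical stand-ins, not Pyro's implementation.

```python
import torch

torch.manual_seed(0)
x = torch.randn(30, 100, 3)      # [plate_2, plate_1, event]
jdx = torch.arange(30)[:, None]  # plate_2 positions, shape [30, 1]
idx = torch.arange(100)          # plate_1 positions, shape [100]

# Enumerated q with max_plate_nesting=2: the enumeration dim is -3.
q_enum = torch.arange(3).reshape(3, 1, 1)
y_enum = x[jdx, idx, q_enum]
print(y_enum.shape)  # torch.Size([3, 30, 100])
assert torch.equal(y_enum[1], x[..., 1])

# Non-enumerated q: one category per plate_1 element.
q = torch.randint(0, 3, (100,))
y = x[jdx, idx, q]
print(y.shape)       # torch.Size([30, 100])
```

`Vindex` spares you from writing `jdx` and `idx` by hand, which is the convenience the event-dim layout buys.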