 # Tensor shapes - enumeration

• What tutorial are you running?
Tensor shapes in Pyro
• What version of Pyro are you using?
1.7.0
model3
```python
import torch
import pyro
from pyro.distributions import Bernoulli, Categorical, Normal
from pyro.infer import TraceEnum_ELBO, config_enumerate


def test_model(model, guide, loss):
    # Helper as defined in the tutorial: reset the param store, evaluate the loss once.
    pyro.clear_param_store()
    loss.loss(model, guide)


@config_enumerate
def model3():
    p = pyro.param("p", torch.arange(6.) / 6)
    locs = pyro.param("locs", torch.tensor([-1., 1.]))

    a = pyro.sample("a", Categorical(torch.ones(6) / 6))
    b = pyro.sample("b", Bernoulli(p[a]))  # Note this depends on a.
    with pyro.plate("c_plate", 4):
        c = pyro.sample("c", Bernoulli(0.3))
        with pyro.plate("d_plate", 5):
            d = pyro.sample("d", Bernoulli(0.4))
            e_loc = locs[d.long()].unsqueeze(-1)
            e_scale = torch.arange(1., 8.)
            e = pyro.sample("e", Normal(e_loc, e_scale)
                            .to_event(1))  # Note this depends on d.

    #                   enumerated|batch|event dims
    assert a.shape == (         6, 1, 1   )  # Six enumerated values of the Categorical.
    assert b.shape == (      2, 1, 1, 1   )  # Two enumerated Bernoullis, unexpanded.
    assert c.shape == (   2, 1, 1, 1, 1   )  # Only two Bernoullis, unexpanded.
    assert d.shape == (2, 1, 1, 1, 1, 1   )  # Only two Bernoullis, unexpanded.
    assert e.shape == (2, 1, 1, 1, 5, 4, 7)  # This is sampled and depends on d.

    assert e_loc.shape   == (2, 1, 1, 1, 1, 1, 1,)
    assert e_scale.shape == (                  7,)


test_model(model3, model3, TraceEnum_ELBO(max_plate_nesting=2))
```

Why doesn’t the shape of sample `d` reflect the plate dimensions under enumeration? I would expect the sizes of the plates `c_plate` and `d_plate`, 4 and 5 respectively, to appear in `d.shape`, yet the `assert` shows a shape of `(2, 1, 1, 1, 1, 1)`. The same question applies to sample `c`.

I can find the plate sizes in the output of `trace.format_shapes()` (full output below): they appear in the `c dist` and `d dist` rows and in the `log_prob` row of each sample.

Reading the shapes alone, it looks as though an enumerated sample holds all of its possible values with no regard to the plate size. Is that right?

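```python
from pyro import poutine

# max_plate_nesting=2 reserves dims -1 and -2 for plates, so enumeration starts at -3.
trace = poutine.trace(poutine.enum(model3, first_available_dim=-3)).get_trace()
trace.compute_log_prob()  # optional, but allows printing of log_prob shapes
print(trace.format_shapes())
```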

Output:

```
Trace Shapes:
 Param Sites:
            p             6
         locs             2
Sample Sites:
       a dist             |
        value       6 1 1 |
     log_prob       6 1 1 |
       b dist       6 1 1 |
        value     2 1 1 1 |
     log_prob     2 6 1 1 |
 c_plate dist             |
        value           4 |
     log_prob             |
       c dist           4 |
        value   2 1 1 1 1 |
     log_prob   2 1 1 1 4 |
 d_plate dist             |
        value           5 |
     log_prob             |
       d dist         5 4 |
        value 2 1 1 1 1 1 |
     log_prob 2 1 1 1 5 4 |
       e dist 2 1 1 1 5 4 | 7
        value 2 1 1 1 5 4 | 7
     log_prob 2 1 1 1 5 4 |
```

> Why doesn’t sample `d` shape reflect the plate dimensions under enumeration?

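It is unnecessary. Every variable that depends on `d` (here `e`) already carries the `c_plate` and `d_plate` dimensions. For each entry `(i, j)` (where `i` belongs to `c_plate` and `j` belongs to `d_plate`), there are just 2 possible values for `d[i, j]`: 0 and 1. Since these two values are the same for every entry, a single unexpanded tensor of shape `(2, 1, 1, 1, 1, 1)` is enough; it broadcasts against the plate dimensions `(5, 4)`, which is why the `log_prob` of `d` and `e` has shape `(2, 1, 1, 1, 5, 4)`. We can then calculate the density at `e[i, j]` for `d[i, j] = 0` and `d[i, j] = 1` respectively, weight each by the prior probability of that value, and sum the results to marginalize out `d`.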
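To make the marginalization concrete, here is a hand-rolled sketch in plain PyTorch of the broadcasting described above. It is an illustration, not Pyro's internal implementation, and it makes simplifying assumptions: `e` is replaced by random stand-in data of shape `(5, 4)` (the event dimension of size 7 is dropped), the scale is fixed at 1.0, and the enumeration dims of `a` and `b` are left out. The plate layout follows the trace above (`d_plate` at dim -2, `c_plate` at dim -1), so entry `(i, j)` is stored at `[j, i]`.

```python
import torch
from torch.distributions import Bernoulli, Normal

torch.manual_seed(0)

locs = torch.tensor([-1.0, 1.0])
# Stand-in observations for e: d_plate (5) at dim -2, c_plate (4) at dim -1.
e = torch.randn(5, 4)

# Enumerated support of d: only the values 0 and 1, unexpanded over the plates.
d_support = torch.arange(2.0).reshape(2, 1, 1)   # shape (2, 1, 1)
log_prior_d = Bernoulli(0.4).log_prob(d_support)  # shape (2, 1, 1)

# log p(e[j, i] | d = k) for both k at once: (2, 1, 1) broadcasts against (5, 4).
e_loc = locs[d_support.long()]                    # shape (2, 1, 1)
log_lik = Normal(e_loc, 1.0).log_prob(e)          # shape (2, 5, 4)

# Marginalize d independently for every plate entry by summing over dim 0.
log_marginal = torch.logsumexp(log_prior_d + log_lik, dim=0)  # shape (5, 4)

# Brute-force check for one entry: log(P(d=0) p(e|d=0) + P(d=1) p(e|d=1)).
j, i = 2, 3
p_d = torch.tensor([0.6, 0.4])
manual = torch.log(
    sum(p_d[k] * Normal(locs[k], 1.0).log_prob(e[j, i]).exp() for k in range(2))
)
```

Only `log_lik` grows with the plates; the enumerated `d_support` stays at shape `(2, 1, 1)` however large the plates are.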