Tensor shapes - enumeration

  • What tutorial are you running?
    Tensor shapes in Pyro
  • What version of Pyro are you using?
    1.7.0
  • Please link or paste relevant code, and steps to reproduce.
model3
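import torch
import pyro
from pyro import poutine
from pyro.distributions import Bernoulli, Categorical, Normal
from pyro.infer import TraceEnum_ELBO, config_enumerate

# test_model (called below) is the helper defined in the tutorial.
@config_enumerate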
def model3():
    p = pyro.param("p", torch.arange(6.) / 6)
    locs = pyro.param("locs", torch.tensor([-1., 1.]))

    a = pyro.sample("a", Categorical(torch.ones(6) / 6))
    b = pyro.sample("b", Bernoulli(p[a]))  # Note this depends on a.
    with pyro.plate("c_plate", 4):
        c = pyro.sample("c", Bernoulli(0.3))
        with pyro.plate("d_plate", 5):
            d = pyro.sample("d", Bernoulli(0.4))
            e_loc = locs[d.long()].unsqueeze(-1)
            e_scale = torch.arange(1., 8.)
            e = pyro.sample("e", Normal(e_loc, e_scale)
                            .to_event(1))  # Note this depends on d.

    #                   enumerated|batch|event dims
    assert a.shape == (         6, 1, 1   )  # Six enumerated values of the Categorical.
    assert b.shape == (      2, 1, 1, 1   )  # Two enumerated Bernoullis, unexpanded.
    assert c.shape == (   2, 1, 1, 1, 1   )  # Only two Bernoullis, unexpanded.
    assert d.shape == (2, 1, 1, 1, 1, 1   )  # Only two Bernoullis, unexpanded.
    assert e.shape == (2, 1, 1, 1, 5, 4, 7)  # This is sampled and depends on d.

    assert e_loc.shape   == (2, 1, 1, 1, 1, 1, 1,)
    assert e_scale.shape == (                  7,)

test_model(model3, model3, TraceEnum_ELBO(max_plate_nesting=2))

Why doesn’t the shape of sample d reflect the plate dimensions under enumeration? I would expect the sizes of c_plate and d_plate, 4 and 5 respectively, to appear in d.shape, yet the assert shows a shape of (2, 1, 1, 1, 1, 1). Same question for sample c.

I can find the plate sizes in trace.format_shapes() (full output below), in the dist and log_prob rows for c and d.

Reading just the shapes, does an enumerated sample hold each possible value once, with no regard to the plate sizes?

trace = poutine.trace(poutine.enum(model3, first_available_dim=-3)).get_trace()
trace.compute_log_prob()  # optional, but allows printing of log_prob shapes
print(trace.format_shapes())

Output:

Trace Shapes:
 Param Sites:
            p             6
         locs             2
Sample Sites:
       a dist             |
        value       6 1 1 |
     log_prob       6 1 1 |
       b dist       6 1 1 |
        value     2 1 1 1 |
     log_prob     2 6 1 1 |
 c_plate dist             |
        value           4 |
     log_prob             |
       c dist           4 |
        value   2 1 1 1 1 |
     log_prob   2 1 1 1 4 |
 d_plate dist             |
        value           5 |
     log_prob             |
       d dist         5 4 |
        value 2 1 1 1 1 1 |
     log_prob 2 1 1 1 5 4 |
       e dist 2 1 1 1 5 4 | 7
        value 2 1 1 1 5 4 | 7
     log_prob 2 1 1 1 5 4 |

Why doesn’t sample d shape reflect the plate dimensions under enumeration?

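Expanding d over the plates is unnecessary. All variables that depend on d should be inside the plates c_plate and d_plate. For each entry (i, j) (where i indexes c_plate and j indexes d_plate), there are just 2 possible values for d[i, j]: 0 and 1. We can compute the density at e[i, j] for d[i, j] = 0 and for d[i, j] = 1, then sum the two results, weighted by the probability of each value, to marginalize out d. The enumerated value of shape (2, 1, 1, 1, 1, 1) simply broadcasts against the plate dimensions, which is why the log_prob row for d shows 2 1 1 1 5 4.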
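
To make the broadcasting concrete, here is a minimal sketch in plain PyTorch (torch.distributions, not Pyro's actual enumeration machinery). It reuses the shapes from model3. The tensor e_obs is a hypothetical stand-in for a value of e with shape (5, 4, 7); in model3 itself e is sampled, not observed.

```python
import torch
from torch.distributions import Bernoulli, Normal

torch.manual_seed(0)

# Enumerated values of d, shaped like d in model3: (2, 1, 1, 1, 1, 1).
d = torch.tensor([0., 1.]).reshape(2, 1, 1, 1, 1, 1)

# Bernoulli(0.4) expanded over d_plate (size 5) and c_plate (size 4).
d_dist = Bernoulli(torch.full((5, 4), 0.4))
d_log_prob = d_dist.log_prob(d)  # value broadcasts against the plate dims

# Same indexing as in model3: e_loc picks a location per enumerated value.
locs = torch.tensor([-1., 1.])
e_loc = locs[d.long()].unsqueeze(-1)   # shape (2, 1, 1, 1, 1, 1, 1)
e_scale = torch.arange(1., 8.)         # shape (7,)

# Hypothetical value of e (an assumption for illustration only).
e_obs = torch.randn(5, 4, 7)

# Summing over the last dim plays the role of .to_event(1).
e_log_prob = Normal(e_loc, e_scale).log_prob(e_obs).sum(-1)

# Marginalize d: add log-probs, then logsumexp over the enumeration dim.
joint = d_log_prob + e_log_prob
marginal = torch.logsumexp(joint, dim=0)

print(d_log_prob.shape, e_log_prob.shape, marginal.shape)
```

Nothing here needed d to be stored with shape (2, 1, 1, 1, 5, 4): two values are enough, and each plate entry (i, j) gets its own weighted sum through broadcasting.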