Confusion about plate grammar and tensor shapes

I tried the following code:

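import torch
import pyro
from pyro.distributions import Normal

with pyro.plate("x_plate", 2):
    with pyro.plate("y_plate", 3):
        a = pyro.sample("a", Normal(0, 1))                  # scalar loc
        b = pyro.sample("b", Normal(torch.zeros(2), 1))     # loc of shape (2,)
        c = pyro.sample("c", Normal(torch.zeros(3, 2), 1))  # loc of shape (3, 2)

print(a.shape, b.shape, c.shape)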

According to the tensor shape tutorial (http://pyro.ai/examples/tensor_shapes.html), it makes sense that a, b and c have the same shapes (torch.Size([3, 2])).

What I don’t understand is, are there any differences between a, b and c? Which definition should be used in what scenario?

Thank you very much!

Hi @changhu

There is no real difference from the inference viewpoint. But a can save memory compared to b and c: you only store a scalar parameter, and Pyro calls expand([3, 2]) on the distribution for you, which does not copy the parameters:

Returns a new distribution instance (or populates an existing instance provided by a derived class) with batch dimensions expanded to batch_shape. This method calls expand on the distribution’s parameters. As such, this does not allocate new memory for the expanded distribution instance.

https://pytorch.org/docs/stable/distributions.html#torch.distributions.distribution.Distribution.expand
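
To make the memory point concrete, here is a minimal sketch in plain PyTorch (no Pyro involved). It assumes that a stride of 0 on a parameter tensor is a fair sign that expand returned a view rather than a copy; the variable names are only illustrative.

```python
import torch
from torch.distributions import Normal

# Distribution with scalar parameters, expanded to batch shape (3, 2).
base = Normal(0.0, 1.0)
expanded = base.expand([3, 2])

# Distribution whose parameters are materialized as full (3, 2) tensors.
dense = Normal(torch.zeros(3, 2), torch.ones(3, 2))

# Both distributions report the same batch shape.
print(expanded.batch_shape, dense.batch_shape)

# A stride of 0 along a dimension means every index along it reads the
# same stored element, so the expanded loc is a view of a single value.
print(expanded.loc.stride())

# The dense loc has ordinary contiguous strides: six separate elements.
print(dense.loc.stride())
```

So the saving comes from what you pass in: a scalar parameter that is expanded stays a single stored value, while torch.zeros(3, 2) already allocates all six entries before the distribution is built.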

Thank you for the explanation!! :blush:

I’m still lost here. If expand does not allocate new memory for the expanded distribution instance, why does avoiding a call to expand save memory?

Also, is there an intuitive reason why different ways of creating more or less the same variable are allowed? (Is it important in dynamic models, for example? Otherwise it seems a bit confusing and error-prone.)

Expand is called by Pyro behind the scenes when the model is traced.

I guess it is a matter of convenience and maybe memory efficiency. PyTorch also allows two equivalent approaches with broadcasting:

torch.tensor([2., 1.]) + 0.0
torch.tensor([2., 1.]) + torch.tensor([0., 0.])

You can think of pyro.plates automatically expanding/broadcasting distribution shapes.
