Hi devs,
I get an error with the following model:
```python
def _model():
    outer_plate = 3
    inner_plate = [2, 4]
    with pyro.plate("outer_plate", outer_plate):
        with pyro.plate("inner_plate_0", inner_plate[0], dim=-3):
            a = pyro.sample("a", dist.Exponential(.2))
    with pyro.plate("outer_plate", outer_plate):
        with pyro.plate_stack("inner_plate", inner_plate, rightmost_dim=-2):
            a_fraction = pyro.sample("a_fraction", dist.Beta(2, 1))
            b = pyro.deterministic("b", a_fraction * a)
```
Here, I want `a` to have shape `(2, 1, 3)` so that it broadcasts correctly against `a_fraction`, which has shape `(2, 4, 3)`. But I get the following error when running this model with NUTS (getting a trace of the model works fine, though):
```
ValueError: Incompatible shapes for broadcasting: shapes=[(4, 1, 1), (2, 1, 3)]
```
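The intended shapes can be checked without Pyro at all. The sketch below rebuilds each site's batch shape from its plates, assuming `outer_plate` is auto-allocated to dim -1 and that `plate_stack` puts the sizes `[2, 4]` on dims -3 and -2, as described above; `batch_shape` is a hypothetical helper, not a Pyro function. It then uses NumPy's `np.broadcast_shapes` (NumPy 1.20+) to confirm that `(2, 1, 3)` and `(2, 4, 3)` broadcast, while the pair of shapes in the error does not.

```python
import numpy as np


def batch_shape(plates, ndim=3):
    """Hypothetical helper: build the batch shape implied by (size, dim)
    plates, with dims counted from the right and unused dims left as 1."""
    shape = [1] * ndim
    for size, dim in plates:
        shape[dim] = size  # negative dim indexes from the right
    return tuple(shape)


# Site "a": outer_plate (size 3, assumed dim -1) and inner_plate_0 (size 2, dim -3).
a_shape = batch_shape([(3, -1), (2, -3)])

# Site "a_fraction": outer_plate (3, -1) plus the plate_stack sizes [2, 4],
# assumed to occupy dims -3 and -2 respectively.
a_fraction_shape = batch_shape([(3, -1), (4, -2), (2, -3)])

print(a_shape)           # (2, 1, 3)
print(a_fraction_shape)  # (2, 4, 3)
print(np.broadcast_shapes(a_shape, a_fraction_shape))  # (2, 4, 3)

# The two shapes reported in the NUTS error cannot be broadcast together:
# the leading sizes 4 and 2 differ and neither is 1.
try:
    np.broadcast_shapes((4, 1, 1), (2, 1, 3))
except ValueError as err:
    print("ValueError:", err)
```

So the shapes the model is meant to produce are compatible; the `(4, 1, 1)` in the error looks like a size-4 plate landing on dim -3, which neither site is supposed to have.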
However, the same model without `plate_stack` works fine:
```python
def _model_without_plate_stack():
    outer_plate = 3
    inner_plate = [2, 4]
    with pyro.plate("outer_plate", outer_plate):
        with pyro.plate("inner_plate_0", inner_plate[0], dim=-3):
            a = pyro.sample("a", dist.Exponential(.2))
    with pyro.plate("outer_plate", outer_plate):
        # with pyro.plate_stack("inner_plate", inner_plate, rightmost_dim=-2):
        with pyro.plate("inner_plate_1", inner_plate[1]):
            with pyro.plate("inner_plate_0", inner_plate[0]):
                a_fraction = pyro.sample("a_fraction", dist.Beta(2, 1))
                b = pyro.deterministic("b", a_fraction * a)
```
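One way to compare the two models is to list the plates each one declares, by name. The sketch below assumes that `plate_stack(prefix, sizes, rightmost_dim)` walks `sizes` from right to left and names each plate `f"{prefix}_{i}"` at `dim=rightmost_dim - i`, which is my reading of the Pyro/NumPyro source and should be checked against the installed version. `plate_stack_layout` is a hypothetical helper that only mimics that naming scheme; it does not call Pyro. Under this assumption, the name `inner_plate_0` means a size-4 plate at dim -2 inside the stack, but a size-2 plate at dim -3 where `a` is sampled, whereas in the hand-written version both uses agree. Whether this name reuse is what breaks NUTS is not something the sketch can confirm.

```python
def plate_stack_layout(prefix, sizes, rightmost_dim=-1):
    """Hypothetical helper mirroring plate_stack's assumed naming scheme:
    iterate over sizes right to left, naming plate i f"{prefix}_{i}" and
    placing it at dim rightmost_dim - i. Returns {name: (size, dim)}."""
    assert rightmost_dim < 0
    return {
        f"{prefix}_{i}": (size, rightmost_dim - i)
        for i, size in enumerate(reversed(sizes))
    }


# Plates assumed to be created by plate_stack("inner_plate", [2, 4], rightmost_dim=-2).
stacked = plate_stack_layout("inner_plate", [2, 4], rightmost_dim=-2)

# Plates declared by hand in _model_without_plate_stack, with the dims that
# automatic allocation is assumed to give them inside outer_plate (dim -1).
manual = {"inner_plate_1": (4, -2), "inner_plate_0": (2, -3)}

# The plate used for site "a" in both models.
a_plate = {"inner_plate_0": (2, -3)}

print(stacked)  # {'inner_plate_0': (4, -2), 'inner_plate_1': (2, -3)}
print(manual["inner_plate_0"] == a_plate["inner_plate_0"])   # True
print(stacked["inner_plate_0"] == a_plate["inner_plate_0"])  # False
```

Both layouts cover the same sizes on the same dims; only the mapping from names to (size, dim) is swapped.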
Any idea why these two approaches behave differently?