Why is plate_stack not working for this model?

Hi devs,

I get this error with the following model:

    def _model():
        """Reproducer: explicit plates for ``a``, a ``plate_stack`` for ``a_fraction``.

        ``a`` is sampled under "outer_plate" (size 3) and "inner_plate_0"
        (size 2, dim=-3). ``a_fraction`` is sampled under "outer_plate" and a
        plate_stack built from ``inner_plate`` with ``rightmost_dim=-2``; their
        product is recorded as the deterministic site "b".
        """
        outer_plate = 3
        inner_plate = [2, 4]

        with pyro.plate("outer_plate", outer_plate):
            with pyro.plate("inner_plate_0", inner_plate[0], dim=-3):
                a = pyro.sample("a", dist.Exponential(.2))

        with pyro.plate("outer_plate", outer_plate):
            # NOTE(review): if plate_stack derives its plate names from the
            # "inner_plate" prefix (e.g. "inner_plate_0"), one of them may clash
            # with the explicit "inner_plate_0" plate above while having a
            # different size/dim -- confirm against the plate_stack docs.
            with pyro.plate_stack("inner_plate", inner_plate, rightmost_dim=-2):
                a_fraction = pyro.sample("a_fraction", dist.Beta(2, 1))
                b = pyro.deterministic("b", a_fraction * a)

Here, I want the shape of a to be (2, 1, 3) so that it broadcasts properly against a_fraction, which has the shape (2, 4, 3). However, I get the following error when running this model with the NUTS sampler (although getting a trace of the same model works fine):

Incompatible shapes for broadcasting: shapes=[(4, 1, 1), (2, 1, 3)]

ValueError: Incompatible shapes for broadcasting: shapes=[(4, 1, 1), (2, 1, 3)]

But the model without the plate_stack works fine:

    def _model_without_plate_stack():
        """Variant of the reproducer that nests explicit plates instead of a plate_stack.

        ``a`` is sampled exactly as in ``_model``. ``a_fraction`` is sampled under
        "outer_plate", "inner_plate_1" (size 4) and "inner_plate_0" (size 2), all
        without an explicit ``dim``; their product is recorded as site "b".
        """
        outer_plate = 3
        inner_plate = [2, 4]

        with pyro.plate("outer_plate", outer_plate):
            with pyro.plate("inner_plate_0", inner_plate[0], dim=-3):
                a = pyro.sample("a", dist.Exponential(.2))

        with pyro.plate("outer_plate", outer_plate):
            # The plate_stack call from `_model` is kept below (commented out) to
            # show what the two nested plates replace.
            # with pyro.plate_stack("inner_plate", inner_plate, rightmost_dim=-2):
            with pyro.plate("inner_plate_1", inner_plate[1]):
                # Reuses the name "inner_plate_0" with the same size as above;
                # its dim is presumably auto-assigned here -- TODO confirm.
                with pyro.plate("inner_plate_0", inner_plate[0]):
                    a_fraction = pyro.sample("a_fraction", dist.Beta(2, 1))
                    b = pyro.deterministic("b", a_fraction * a)

Any idea why there’s a difference in these approaches?

What happens if you specify dim in the outer plate?

Hi fehiepsi, sorry for the late reply; I was away for the weekend. Specifying dim in the outer plate throws the same error. Here is a complete, runnable example:

from jax import random
import numpyro as pyro
from numpyro import distributions as dist
from numpyro.infer import NUTS, MCMC


def _model():
    """Reproducer: explicit plates for ``a``, a ``plate_stack`` for ``a_fraction``.

    ``a`` is sampled under "outer_plate" (size 3) and "inner_plate_0"
    (size 2, dim=-3). ``a_fraction`` is sampled under "outer_plate" and a
    plate_stack built from ``inner_plate`` with ``rightmost_dim=-2``; their
    product is recorded as the deterministic site "b".
    """
    outer_plate = 3
    inner_plate = [2, 4]

    with pyro.plate("outer_plate", outer_plate):
        with pyro.plate("inner_plate_0", inner_plate[0], dim=-3):
            a = pyro.sample("a", dist.Exponential(.2))

    with pyro.plate("outer_plate", outer_plate):
        # NOTE(review): if plate_stack derives its plate names from the
        # "inner_plate" prefix (e.g. "inner_plate_0"), one of them may clash
        # with the explicit "inner_plate_0" plate above while having a
        # different size/dim -- confirm against the plate_stack docs.
        with pyro.plate_stack("inner_plate", inner_plate, rightmost_dim=-2):
            a_fraction = pyro.sample("a_fraction", dist.Beta(2, 1))
            b = pyro.deterministic("b", a_fraction * a)


def _model_without_plate_stack():
    """Variant of the reproducer that nests explicit plates instead of a plate_stack.

    ``a`` is sampled exactly as in ``_model``. ``a_fraction`` is sampled under
    "outer_plate", "inner_plate_1" (size 4) and "inner_plate_0" (size 2), all
    without an explicit ``dim``; their product is recorded as site "b".
    """
    outer_plate = 3
    inner_plate = [2, 4]

    with pyro.plate("outer_plate", outer_plate):
        with pyro.plate("inner_plate_0", inner_plate[0], dim=-3):
            a = pyro.sample("a", dist.Exponential(.2))

    with pyro.plate("outer_plate", outer_plate):
        # The plate_stack call from `_model` is kept below (commented out) to
        # show what the two nested plates replace.
        # with pyro.plate_stack("inner_plate", inner_plate, rightmost_dim=-2):
        with pyro.plate("inner_plate_1", inner_plate[1]):
            # Reuses the name "inner_plate_0" with the same size as above;
            # its dim is presumably auto-assigned here -- TODO confirm.
            with pyro.plate("inner_plate_0", inner_plate[0]):
                a_fraction = pyro.sample("a_fraction", dist.Beta(2, 1))
                b = pyro.deterministic("b", a_fraction * a)


def _model_03():
    """Same as ``_model`` but with ``dim=-1`` given explicitly on "outer_plate".

    Written to answer the question of whether pinning the outer plate's dim
    changes the outcome. ``a`` is sampled under "outer_plate" (size 3, dim=-1)
    and "inner_plate_0" (size 2, dim=-3); ``a_fraction`` under "outer_plate"
    (dim=-1) and a plate_stack with ``rightmost_dim=-2``. Their product is
    recorded as the deterministic site "b".
    """
    outer_plate = 3
    inner_plate = [2, 4]

    with pyro.plate("outer_plate", outer_plate, dim=-1):
        with pyro.plate("inner_plate_0", inner_plate[0], dim=-3):
            a = pyro.sample("a", dist.Exponential(.2))

    with pyro.plate("outer_plate", outer_plate, dim=-1):
        # NOTE(review): if plate_stack derives its plate names from the
        # "inner_plate" prefix (e.g. "inner_plate_0"), one of them may clash
        # with the explicit "inner_plate_0" plate above while having a
        # different size/dim -- confirm against the plate_stack docs.
        with pyro.plate_stack("inner_plate", inner_plate, rightmost_dim=-2):
            a_fraction = pyro.sample("a_fraction", dist.Beta(2, 1))
            b = pyro.deterministic("b", a_fraction * a)


def trace(key, model):
    """Run ``model`` once under a seeded handler and return its execution trace.

    Args:
        key: Value passed as ``rng_seed`` to ``pyro.handlers.seed``.
        model: Zero-argument model callable to trace.

    Returns:
        Whatever ``pyro.handlers.trace(model).get_trace()`` returns.
    """
    with pyro.handlers.seed(rng_seed=key):
        # Named ``model_trace`` so the local does not shadow this function.
        model_trace = pyro.handlers.trace(model).get_trace()
    # Only reached when tracing raised no exception.
    print("Trace successful")
    return model_trace


def run(key, model):
    """Sample ``model`` with a NUTS kernel and return the MCMC object.

    Args:
        key: Random key handed to ``MCMC.run``.
        model: Model callable to run inference on.

    Returns:
        The ``MCMC`` instance after ``run`` has completed (100 warmup steps,
        100 samples).
    """
    sampler = MCMC(NUTS(model), num_warmup=100, num_samples=100)
    sampler.run(key)
    # Only reached when sampling raised no exception.
    print("Run successful")
    return sampler
    

# Driver: trace each model once, then run NUTS on it, using the same key.
# The trailing "works"/"fails" comments are the outcomes reported by the author.
# NOTE(review): executed top to bottom as one script, the first failing call
# would raise and stop execution -- presumably each pair was run separately
# (e.g. in notebook cells); confirm.
key = random.key(0)

trace01 = trace(key, _model) # works
mcmc01 = run(key, _model) # fails

trace02 = trace(key, _model_without_plate_stack) # works
mcmc02 = run(key, _model_without_plate_stack) # works

trace03 = trace(key, _model_03) # works
mcmc03 = run(key, _model_03) # fails

Thanks @mathlad! I think it is a bug. Could you create a github issue for this?

ok.