Error with enumeration for hierarchy of GLMs

Hi all. I have a model that is a hierarchy of GLMs, where I want to test a null vs. full model for each sub-GLM (there are parameters shared between these GLMs in the real version, although not in the MWE below). Possibly I'm hitting this bug, possibly I'm doing something dumb (edit: I've tried the fix from issue 2860 to no avail). The error I get is:

ValueError: Error while packing tensors at site 'obs':
  Invalid tensor shape.
  Allowed dims: -2, -1
  Actual shape: (2, 10, 100)

Full MWE:

import torch
import pyro
import pyro.distributions as dist
from pyro.infer.autoguide import AutoDiagonalNormal
from pyro.infer import config_enumerate
from pyro.ops.indexing import Vindex
from pyro import poutine

@config_enumerate
def model(x_null, x_full, y, n): 

    P_full = x_full.shape[1]
    P_null = x_null.shape[1]
    N, J = y.shape
    device = x_full.device

    with pyro.plate("P_null", P_null): # beta is P (covariates) x J (junctions)
        beta_null = pyro.sample("beta_null", dist.Normal(0., 1.).expand([P_null, J]).to_event(1)) 

    with pyro.plate("P_full", P_full): # beta is P (covariates) x J (junctions)
        beta_full = pyro.sample("beta_full", dist.Normal(0., 1.).expand([P_full, J]).to_event(1)) 

    logits_full = x_full @ beta_full # logits is N x J
    logits_null = x_null @ beta_null

    logits_combined = torch.stack((logits_null,logits_full)).transpose(1,2) # 2 x J x N

    with pyro.plate("data", J): # over individual regressions
        conc_param = pyro.sample("conc", dist.Gamma(2., .2)) # per-junction concentration
        assignment = pyro.sample(
            'assignment',
            dist.Bernoulli(0.1)
        ).long() # 0 = null model, 1 = full model
        J_arange = torch.arange(J, device = device)
        logits = Vindex(logits_combined)[assignment,J_arange,:].transpose(-1,-2) # 2 x N x J
        #logits = logits_combined[assignment,J_arange].transpose(-1,-2)
        p = logits.sigmoid()
        a = p * conc_param + 1e-8       # BetaBinomial alpha
        b = (1.-p) * conc_param + 1e-8  # BetaBinomial beta
        pyro.sample("obs", dist.BetaBinomial(a, b, total_count = n), obs=y)

# simulate data
N = 10      # individuals
P = 5       # covariates in the null model
J = 100     # junctions (i.e. individual regressions)
conc = 10.  # BetaBinomial concentration used for simulation
x_null = torch.randn((N,P))
b = torch.randn((P,J))
g = (x_null @ b).sigmoid() # N x J success probabilities
n = dist.Poisson(50).sample([N,J]) # total counts
y = dist.BetaBinomial(g * conc, (1.-g)*conc, total_count = n).sample()
x_full = torch.concat([x_null,torch.randn(N,1)],1) # full model adds one extra covariate

# attempt to calculate ELBO
guide = AutoDiagonalNormal(poutine.block(model, hide = ['assignment']))
elbo_func = pyro.infer.TraceEnum_ELBO(max_plate_nesting=1)
print(elbo_func(model, guide)(x_null, x_full, y, n))
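
In case it's useful, here's my (possibly wrong) reading of where the (2, 10, 100) shape comes from: the leading 2 would be the enumeration dim for assignment (Bernoulli enumerates to two values), and 10 x 100 is N x J coming out of the indexing step. The pure-torch sketch below mimics that indexing under my assumption that enumeration gives assignment a shape of (2, 1) inside the "data" plate; plain indexing should give the same shape as Vindex here.

import torch

N, J = 10, 100
logits_combined = torch.randn(2, J, N)   # 2 x J x N, as in the model
assignment = torch.randint(0, 2, (2, 1)) # assumed enumerated shape: (enum dim 2) x (plate dim 1)
J_arange = torch.arange(J)
logits = logits_combined[assignment, J_arange, :].transpose(-1, -2)
print(logits.shape) # torch.Size([2, 10, 100]) -- same as "Actual shape" in the error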