Hi all. I have a model which is a hierarchy of GLMs where I want to test a null vs full model for each sub-GLM (there are params shared between these GLMs in the real version although not in the MWE below). Possibly I’m hitting this bug, possibly I’m doing something dumb (edit: I’ve tried the fix to issue 2860 to no avail). The error I get is
ValueError: Error while packing tensors at site 'obs':
Invalid tensor shape.
Allowed dims: -2, -1
Actual shape: (2, 10, 100)
Full MWE:
import torch
import pyro
import pyro.distributions as dist
from pyro.infer.autoguide import AutoDiagonalNormal, AutoGuideList, AutoDelta, init_to_value
from pyro.infer import SVI, Trace_ELBO, config_enumerate, infer_discrete
from pyro.ops.indexing import Vindex
from pyro import poutine
@config_enumerate
def model(x_null, x_full, y, n):
P_full = x_full.shape[1]
P_null = x_null.shape[1]
N, J = y.shape
device = x_full.device
with pyro.plate("P_null", P_null): # beta is P (covariates) x J (junctions)
beta_null = pyro.sample("beta_null", dist.Normal(0., 1.).expand([P_null, J]).to_event(1))
with pyro.plate("P_full", P_full): # beta is P (covariates) x J (junctions)
beta_full = pyro.sample("beta_full", dist.Normal(0., 1.).expand([P_full, J]).to_event(1))
logits_full = x_full @ beta_full # logits is N x J
logits_null = x_null @ beta_null
logits_combined = torch.stack((logits_null,logits_full)).transpose(1,2) # 2 x J x N
with pyro.plate("data", J): # over individual regressions
conc_param = pyro.sample("conc", dist.Gamma(2.,.2))
assignment = pyro.sample(
'assignment',
dist.Bernoulli(0.1)
).long()
J_arange = torch.arange(J, device = device)
logits = Vindex(logits_combined)[assignment,J_arange,:].transpose(-1,-2) # 2 x N x J
#logits = logits_combined[assignment,J_arange].transpose(-1,-2)
p = logits.sigmoid()
a = p * conc_param + 1e-8
b = (1.-p) * conc_param + 1e-8
pyro.sample("obs", dist.BetaBinomial(a, b, total_count = n), obs=y)
# simulate data
N = 10
P = 5
J = 100
conc = 10.
x_null = torch.randn((N,P))
b = torch.randn((P,J))
g = (x_null @ b).sigmoid()
n = dist.Poisson(50).sample([N,J])
y = dist.BetaBinomial(g * conc, (1.-g)*conc, total_count = n).sample()
x_full = torch.concat([x_null,torch.randn(N,1)],1)
# attempt to calculate ELBO
guide = AutoDiagonalNormal(poutine.block(model, hide = ['assignment']))
elbo_func = pyro.infer.TraceEnum_ELBO(max_plate_nesting=1)
print(elbo_func(model, guide)(x_null, x_full, y, n))