Subsampling pyro.plate Results in Shape Mismatch

I’m trying to subsample the observation likelihood in the ELBO computation, using the last plate in my model below, but it keeps raising ‘ValueError: Shape mismatch inside plate(‘data’) at site y dim -1, 10000 vs 232448’.

I looked into how to subsample, and it looks like I’m following what is prescribed. I assumed pyro.plate picks the indices to subsample and that the shapes become consistent once I pass those indices to index_select(), so I’m not sure what the issue is. The model runs fine without subsampling, so I’m fairly sure it’s not a model definition issue. Any clues on what could be wrong in the code below?

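import torch
import pyro
import pyro.distributions as dist
from pyro import optim, poutine
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoDiagonalNormal
from pyro.infer.reparam import LocScaleReparam

def horseshoe_classification2(data, y):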
  N, feature_dim = data.shape
  # - Horseshoe prior.
  # Global shrinkage parameter (tau) multiplies a coefficient's local shrinkage parameter (lambda_i) to compute the
  # standard deviation of the Gaussian that generates that coefficient (beta_i).
  tau = pyro.sample('tau', dist.HalfCauchy(torch.ones(1)))
  # Local shrinkage parameter (lambda_i) represents the local component of the standard deviation of the Gaussian that
  # generates that coefficient (beta_i).
  with pyro.plate('lambdas_plate', feature_dim):
    lambdas = pyro.sample('lambdas', dist.HalfCauchy(torch.ones(feature_dim)))
  # The horseshoe prior assumes each coefficient (beta_i) is conditionally independent.
  with pyro.plate('beta_plate', feature_dim):
    # Reparameterize to improve posterior geometry (not specific to horseshoe regression).
    with poutine.reparam(config={'betas': LocScaleReparam()}):
      betas = pyro.sample('betas', dist.Normal(0, tau * lambdas))
  # Kappa_i is roughly interpreted as the amount of weight that the posterior mean for beta_i places on 0 after the data
  # has been observed (this interpretation is primarily for regression when sigma^2 and tau^2 both equal 1).
  pyro.deterministic('kappas', 1/(1 + lambdas**2))
  # - Intercept prior.
  intercept = pyro.sample('intercept', dist.Normal(0, 10))
  # - Linear model.
  p = data @ betas + intercept
  # - Likelihood.
  with pyro.plate('data', size=N, subsample_size=10000) as ind:
    pyro.sample('y', dist.Bernoulli(logits=p), obs=y.index_select(0, ind))

# Set up training objects.
guide = AutoDiagonalNormal(horseshoe_classification2)
svi = SVI(horseshoe_classification2, guide, optim.Adam({"lr": 0.01}), loss=Trace_ELBO())

# Train model.
num_iters = 5000
for i in range(num_iters):
  elbo = svi.step(X_train, y_train)
  if i % 500 == 0:
    print("Elbo loss: {}".format(elbo))

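You need to subsample everywhere consistently. The logits p are computed from all N rows of data (presumably the 232448 in the error), while y is indexed down to the 10000 subsampled rows, so the two shapes disagree inside the plate. Perhaps you need something like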

p = data[ind] ...

possibly with other changes elsewhere.

Ah OK, I see: p wasn’t being subsampled as well. Instead of creating another plate for p, I applied index_select to p too, and that seems to work (the model converges to a similar loss as the non-subsampled version). So I just changed the likelihood block at the end of the model to:

  # - Likelihood.
  with pyro.plate('data', size=N, subsample_size=10000) as ind:
    pyro.sample('y', dist.Bernoulli(logits=p.index_select(0, ind)), obs=y.index_select(0, ind))
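
On why the loss lands near the non-subsampled value: as far as I understand, pyro.plate rescales the log-probability of subsampled sites by size / subsample_size, so the minibatch term is an unbiased estimate of the full-data log-likelihood. The sketch below checks that arithmetic in plain Python, without Pyro. The helper names and the synthetic data are hypothetical, and it only mimics the rescaling, not Pyro's internals:

```python
import math
import random


def bernoulli_log_prob(logit, y):
    # log p(y | logit) for y in {0, 1}: y * logit - log(1 + exp(logit))
    return y * logit - math.log1p(math.exp(logit))


def full_log_lik(logits, ys):
    # Log-likelihood summed over every observation.
    return sum(bernoulli_log_prob(l, y) for l, y in zip(logits, ys))


def subsampled_log_lik(logits, ys, subsample_size, rng):
    # Draw indices without replacement and apply the SAME indices to both
    # the logits and the observations, then rescale by N / subsample_size.
    n = len(ys)
    ind = rng.sample(range(n), subsample_size)
    batch = sum(bernoulli_log_prob(logits[i], ys[i]) for i in ind)
    return (n / subsample_size) * batch


rng = random.Random(0)
N, B = 1000, 100  # synthetic sizes, much smaller than in the question
logits = [rng.gauss(0.0, 1.0) for _ in range(N)]
ys = [1.0 if rng.random() < 1.0 / (1.0 + math.exp(-l)) else 0.0
      for l in logits]

full = full_log_lik(logits, ys)
estimates = [subsampled_log_lik(logits, ys, B, rng) for _ in range(2000)]
mean_estimate = sum(estimates) / len(estimates)
print(full, mean_estimate)
```

Individual minibatch estimates are noisy, but their average should sit close to the full-data value, which is consistent with the subsampled run converging to a similar loss.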

Thanks for your help!