I'm trying to subsample the observation likelihood in the ELBO computation (the last plate in my model below), but it keeps failing with `ValueError: Shape mismatch inside plate('data') at site y dim -1, 10000 vs 232448`.
I looked into how to subsample, and as far as I can tell I'm following what was prescribed. I'm assuming `pyro.plate` picks the indices to subsample and makes the shapes consistent when I pass those indices to `index_select()`, so I'm not sure what the issue is. The model runs fine without subsampling, so I'm fairly sure it's not a problem with the model definition. Any clues on what could be wrong in the code below?
```python
import torch
import pyro
import pyro.distributions as dist
import pyro.optim as optim
from pyro import poutine
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoDiagonalNormal
from pyro.infer.reparam import LocScaleReparam


def horseshoe_classification2(data, y):
    N, feature_dim = data.shape
    # - Horseshoe prior.
    # Global shrinkage parameter (tau) multiplies a coefficient's local shrinkage parameter (lambda_i) to compute the
    # standard deviation of the Gaussian that generates that coefficient (beta_i).
    tau = pyro.sample('tau', dist.HalfCauchy(torch.ones(1)))
    # Local shrinkage parameter (lambda_i) represents the local component of the standard deviation of the Gaussian that
    # generates that coefficient (beta_i).
    with pyro.plate('lambdas_plate', feature_dim):
        lambdas = pyro.sample('lambdas', dist.HalfCauchy(torch.ones(feature_dim)))
    # The horseshoe prior assumes each coefficient (beta_i) is conditionally independent.
    with pyro.plate('beta_plate', feature_dim):
        # Reparameterize to improve posterior geometry (not specific to horseshoe regression).
        with poutine.reparam(config={'betas': LocScaleReparam()}):
            betas = pyro.sample('betas', dist.Normal(0, tau * lambdas))
    # Kappa_i is roughly interpreted as the amount of weight that the posterior mean for beta_i places on 0 after the data
    # has been observed (this interpretation is primarily for regression when sigma^2 and tau^2 both equal 1).
    pyro.deterministic('kappas', 1 / (1 + lambdas**2))
    # - Intercept prior.
    intercept = pyro.sample('intercept', dist.Normal(0, 10))
    # - Linear model.
    p = data @ betas + intercept
    # - Likelihood.
    with pyro.plate('data', size=N, subsample_size=10000) as ind:
        pyro.sample('y', dist.Bernoulli(logits=p), obs=y.index_select(0, ind))


# Set up training objects.
guide = AutoDiagonalNormal(horseshoe_classification2)
svi = SVI(horseshoe_classification2, guide, optim.Adam({"lr": 0.01}), loss=Trace_ELBO())

# Train model.
num_iters = 5000
for i in range(num_iters):
    elbo = svi.step(X_train, y_train)
    if i % 500 == 0:
        print("Elbo loss: {}".format(elbo))
```