I’m having trouble getting subsampling working for a model with nested plates. Model, MWE, and error below. m, c_ij, and s_ij are all binary (but set as Categorical to allow for future expansion of the domains). Their CPDs (the p variables) are all drawn from a currently uninformative Dirichlet prior. The observed variables are m, c, and s (for machine, metric, sample).
If I leave out subsampling, SVI runs fine, and I’m able to recover the correct CPDs for p_m and p_s (based on training data I simulate; not shown here). However, I can’t yet correctly recover p_c, so I suspect something is off with the indexing into that CPD on the c sample line.
Additionally, and perhaps more simply, when I try to use subsampling on the inner plate, I get a dimension-mismatch error at site c. Looking at the shape trace, it seems the value at s is not being subsampled, although its distribution is. This is odd, because if I examine the shapes of the variables returned by m, s, c = model(S, R), they all look correctly subsampled.
Any idea why this is happening? I’m hoping that fixing this subsampling error will in turn allow the model to properly recover p_c. Thanks!
MWE:
import numpy as np
import torch
from torch.distributions import constraints
from pyro.ops.indexing import Vindex
import pyro
import pyro.optim
import pyro.infer
import pyro.distributions as dist

R = 20
S = 100
subsample_size = 10

# Initial values for the guide parameters (definitions omitted from the post;
# all-ones tensors of matching shape are assumed here).
machine_prior_t = torch.ones(2)
sample_prior_t = torch.ones(2)
metric_prior_t = torch.ones(2, 2, 2)

def model(S, R):
    run_plate = pyro.plate('run_plate', R, dim=-1)
    sample_plate = pyro.plate('sample_plate', S, dim=-2, subsample_size=subsample_size)
    # global CPDs
    p_m = pyro.sample('p_m', dist.Dirichlet(torch.ones(2)))
    p_s = pyro.sample('p_s', dist.Dirichlet(torch.ones(2)))
    with pyro.plate_stack("c_plate", (2, 2)):
        p_c = pyro.sample('p_c', dist.Dirichlet(torch.ones(2, 2, 2)))
    # data
    with run_plate:
        # machine
        m = pyro.sample('m', dist.Categorical(p_m))
        # sample and metric
        with sample_plate:
            s = pyro.sample('s', dist.Categorical(p_s))
            c = pyro.sample('c', dist.Categorical(Vindex(p_c)[m, s, :]))  # (machine, sample, metric)
    return m, s, c

def guide(S, R):
    # machine
    m_prior = pyro.param('m_post', machine_prior_t, constraint=constraints.positive)
    p_m = pyro.sample('p_m', dist.Dirichlet(m_prior))
    # sample
    s_prior = pyro.param('s_post', sample_prior_t, constraint=constraints.positive)
    p_s = pyro.sample('p_s', dist.Dirichlet(s_prior))
    # metric
    c_prior = pyro.param('c_post', metric_prior_t, constraint=constraints.positive)  # (machine, sample, metric)
    with pyro.plate_stack("c_plate", (2, 2)):
        p_c = pyro.sample('p_c', dist.Dirichlet(c_prior))

# condition model on dummy MWE data
observations = {'m': torch.ones(R, dtype=torch.long),
                's': torch.ones((S, R), dtype=torch.long),
                'c': torch.ones((S, R), dtype=torch.long),
                }
cond_model = pyro.condition(model, data=observations)

# run svi
svi = pyro.infer.SVI(cond_model, guide, pyro.optim.Adam({"lr": .0001}), loss=pyro.infer.Trace_ELBO())
for step in range(10000):
    svi.step(S, R)
Error:
ValueError: Shape mismatch inside plate('sample_plate') at site c dim -2, 10 vs 100
    Trace Shapes:
     Param Sites:
    Sample Sites:
   run_plate dist        |
            value     20 |
sample_plate dist        |
            value     10 |
         p_m dist        | 2
            value        | 2
         p_s dist        | 2
            value        | 2
   c_plate_0 dist        |
            value      2 |
   c_plate_1 dist        |
            value      2 |
         p_c dist    2 2 | 2
            value    2 2 | 2
           m dist     20 |
            value     20 |
           s dist  10 20 |
            value 100 20 |