Hi everyone,
I’m currently trying to use the following pattern for SVI:
```python
infer.SVI(model, autoguide.AutoDelta(model), optimizer, loss=infer.Trace_ELBO())
```
My model takes mini-batches of size B (say 64) out of a total data size M (say 100). This means the last mini-batch has size 36, and that perhaps does not go well with AutoDelta, because I get an error like this:
```
ValueError: Shape mismatch inside plate('X_plate') at site Z_1 dim -1, 64 vs 36
Trace Shapes:
 Param Sites:
  dmm$$$h_0 1 128
  dmm$$$z_0 1 128
  dmm$$$transition.gru.weight_ih 384 128
  dmm$$$transition.gru.weight_hh 384 128
  dmm$$$transition.gru.bias_ih 384
  dmm$$$transition.gru.bias_hh 384
  dmm$$$transition.mu_head.0.weight 128 128
  dmm$$$transition.mu_head.0.bias 128
  dmm$$$transition.mu_head.2.weight 128 128
  dmm$$$transition.mu_head.2.bias 128
  dmm$$$transition.log_var_head.0.weight 128 128
  dmm$$$transition.log_var_head.0.bias 128
  dmm$$$transition.log_var_head.2.weight 128 128
  dmm$$$transition.log_var_head.2.bias 128
  dmm$$$emitter.mu_head.0.weight 128 128
  dmm$$$emitter.mu_head.0.bias 128
  dmm$$$emitter.mu_head.2.weight 13 128
  dmm$$$emitter.mu_head.2.bias 13
  dmm$$$emitter.log_var_head.0.weight 128 128
  dmm$$$emitter.log_var_head.0.bias 128
  dmm$$$emitter.log_var_head.2.weight 13 128
  dmm$$$emitter.log_var_head.2.bias 13
  dmm$$$missing.net.0.weight 128 128
  dmm$$$missing.net.0.bias 128
  dmm$$$missing.net.2.weight 12 128
  dmm$$$missing.net.2.bias 12
  dmm$$$pheno.net.0.weight 128 128
  dmm$$$pheno.net.0.bias 128
  dmm$$$pheno.net.2.weight 25 128
  dmm$$$pheno.net.2.bias 25
Sample Sites:
  X_plate dist    |
         value 64 |
```
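The two sizes in the error (64 vs 36) follow directly from the batching described above. A minimal sketch, assuming the M rows are split into consecutive chunks of B rows with the remainder going to the final batch; `batch_sizes` is a hypothetical helper, not part of Pyro:

```python
def batch_sizes(M, B):
    """Sizes of consecutive mini-batches of at most B rows covering M rows (hypothetical helper)."""
    return [min(B, M - start) for start in range(0, M, B)]


print(batch_sizes(100, 64))  # [64, 36]
```

Whenever M is not a multiple of B, the final batch is smaller than the others, so any guide state sized from the first batch will disagree with it.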
I don't hit this problem when I use an empty guide. Any ideas what is happening here?
My model looks like this:
```python
with pyro.plate('X_plate', B, device=X.device):
    for t in range(T_max):
        with poutine.mask(mask=(t < L)):
            h_t, z_mu, z_log_var = self.transition(z_prev, h_prev)
            z_dist = dist.Normal(z_mu, z_log_var.exp()).to_event(1)
            z_t = pyro.sample('Z_{}'.format(t + 1), z_dist)
            df = pyro.param('df', torch.tensor(2.0, device=X.device),
                            constraint=constraints.positive)
            x_mu, x_log_var = self.emitter(z_t)
            x_dist = dist.StudentT(df, x_mu, x_log_var.exp()).to_event(1)
            pyro.sample('X_{}'.format(t + 1), x_dist, obs=X[t, :, :13])
            missingness_p = self.missing(z_t)
            m_dist = dist.Bernoulli(missingness_p).to_event(1)
            pyro.sample('I_{}'.format(t + 1), m_dist, obs=X[t, :, 13:13 + 12])
            z_pheno[t == L] = z_t[t == L]
            h_prev = h_t
            z_prev = z_t
with pyro.plate('P_plate', 25, device=pheno.device):
    with pyro.plate('PX_plate', B, device=X.device):
        pheno_p = self.pheno(z_pheno)
        pheno_dist = dist.Bernoulli(pheno_p)
        pyro.sample('P', pheno_dist, obs=pheno)
```
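In the loop above, `poutine.mask(mask=(t < L))` zeroes the log-probability terms of every sequence whose length `L` is not greater than `t`. A plain-Python sketch of the per-step masks, assuming `L` holds one length per sequence in the batch (`step_masks` and the lengths below are made up for illustration):

```python
def step_masks(lengths, T_max):
    """For each time step t, one boolean per sequence: True while t < length."""
    return [[t < length for length in lengths] for t in range(T_max)]


lengths = [3, 1, 2]  # hypothetical per-sequence lengths
for t, mask in enumerate(step_masks(lengths, T_max=3)):
    print(t, mask)
# 0 [True, True, True]
# 1 [True, False, True]
# 2 [True, False, False]
```

Masked steps are still sampled; they just contribute nothing to the ELBO, so every `Z_t` site keeps the full batch size B as its plate dimension.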
I've checked that the model-side batch shapes look sane on the last batch:
```
B: 36
z_mu.shape: torch.Size([36, 128])
z_dist.batch_shape: torch.Size([36])
```
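My guess, which I haven't confirmed: AutoDelta builds its parameters from the first trace it sees, so the point estimate for a local latent such as `Z_1` inside `X_plate` is created with the first batch's size (64) and cannot line up with the last batch (36). An empty guide has no per-datapoint parameters, which would fit with it working fine. The toy below mimics this in plain Python; `LazyPointGuide` and `FullSizePointGuide` are hypothetical stand-ins, not Pyro classes. The second one allocates one value per row of the full dataset (size M) and picks out the current batch by row index, which I believe is roughly what declaring the plate as `pyro.plate('X_plate', M, subsample=idx)` (with `idx` a hypothetical tensor of the batch's row indices) lets AutoDelta do, but I haven't checked that.

```python
class LazyPointGuide:
    """Hypothetical toy: one point estimate per data point, sized on first call."""

    def __init__(self):
        self.locs = None

    def __call__(self, batch_size):
        if self.locs is None:
            # The size is frozen by whichever batch happens to come first.
            self.locs = [0.0] * batch_size
        if len(self.locs) != batch_size:
            raise ValueError(f"Shape mismatch: {len(self.locs)} vs {batch_size}")
        return self.locs


class FullSizePointGuide:
    """Hypothetical toy: one point estimate per row of the full dataset."""

    def __init__(self, full_size):
        self.locs = [0.0] * full_size

    def __call__(self, indices):
        # Select only the rows that belong to the current mini-batch.
        return [self.locs[i] for i in indices]


lazy = LazyPointGuide()
lazy(64)  # first batch: 64 point estimates are created
try:
    lazy(36)  # last batch
except ValueError as err:
    print(err)  # Shape mismatch: 64 vs 36

full = FullSizePointGuide(100)
print(len(full(range(0, 64))), len(full(range(64, 100))))  # 64 36
```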