Dimension mismatch using AutoDelta and Mini-Batch

Hi everyone,

I’m currently trying to use the following pattern for SVI:

infer.SVI(model, autoguide.AutoDelta(model), optimizer, loss=infer.Trace_ELBO())

My model takes mini-batches of size B (say 64) from a dataset of total size M (say 100), so the last mini-batch has size 36. AutoDelta does not seem to handle this and fails with an error like this:

ValueError: Shape mismatch inside plate('X_plate') at site Z_1 dim -1, 64 vs 36
                         Trace Shapes:        
                          Param Sites:        
                             dmm$$$h_0   1 128
                             dmm$$$z_0   1 128
        dmm$$$transition.gru.weight_ih 384 128
        dmm$$$transition.gru.weight_hh 384 128
          dmm$$$transition.gru.bias_ih     384
          dmm$$$transition.gru.bias_hh     384
     dmm$$$transition.mu_head.0.weight 128 128
       dmm$$$transition.mu_head.0.bias     128
     dmm$$$transition.mu_head.2.weight 128 128
       dmm$$$transition.mu_head.2.bias     128
dmm$$$transition.log_var_head.0.weight 128 128
  dmm$$$transition.log_var_head.0.bias     128
dmm$$$transition.log_var_head.2.weight 128 128
  dmm$$$transition.log_var_head.2.bias     128
        dmm$$$emitter.mu_head.0.weight 128 128
          dmm$$$emitter.mu_head.0.bias     128
        dmm$$$emitter.mu_head.2.weight  13 128
          dmm$$$emitter.mu_head.2.bias      13
   dmm$$$emitter.log_var_head.0.weight 128 128
     dmm$$$emitter.log_var_head.0.bias     128
   dmm$$$emitter.log_var_head.2.weight  13 128
     dmm$$$emitter.log_var_head.2.bias      13
            dmm$$$missing.net.0.weight 128 128
              dmm$$$missing.net.0.bias     128
            dmm$$$missing.net.2.weight  12 128
              dmm$$$missing.net.2.bias      12
              dmm$$$pheno.net.0.weight 128 128
                dmm$$$pheno.net.0.bias     128
              dmm$$$pheno.net.2.weight  25 128
                dmm$$$pheno.net.2.bias      25
                         Sample Sites:        
                          X_plate dist       |
                                 value  64   |

I don’t see this problem if I use an empty guide. Any ideas what is happening here?

My model looks like this:

    with pyro.plate('X_plate', B, device=X.device):
      for t in range(T_max):
        with poutine.mask(mask=(t < L)):
          h_t, z_mu, z_log_var = self.transition(z_prev, h_prev)
          z_dist = dist.Normal(z_mu, z_log_var.exp()).to_event(1)

          z_t = pyro.sample('Z_{}'.format(t + 1), z_dist)

          df = pyro.param('df', torch.tensor(2.0, device=X.device),
                          constraint=constraints.positive)
          x_mu, x_log_var = self.emitter(z_t)
          x_dist = dist.StudentT(df, x_mu, x_log_var.exp()).to_event(1)

          pyro.sample('X_{}'.format(t + 1), x_dist, obs=X[t, :, :13])

          missingness_p = self.missing(z_t)
          m_dist = dist.Bernoulli(missingness_p).to_event(1)
          pyro.sample('I_{}'.format(t + 1), m_dist, obs=X[t, :, 13:13 + 12])

        z_pheno[t == L] = z_t[t == L]

        h_prev = h_t
        z_prev = z_t

    with pyro.plate('P_plate', 25, device=pheno.device):
      with pyro.plate('PX_plate', B, device=X.device):
        pheno_p = self.pheno(z_pheno)
        pheno_dist = dist.Bernoulli(pheno_p)

        pyro.sample('P', pheno_dist, obs=pheno)

I’ve verified that the batch shapes look sane for the last mini-batch:

B: 36
z_mu.shape: torch.Size([36, 128])
z_dist.batch_shape: torch.Size([36])

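Pyro’s autoguides currently do not work with minibatched sample statements (observe statements, i.e. pyro.sample(..., obs=...), are fine). An autoguide such as AutoDelta allocates its parameter for each latent site from the first batch it sees, so the Z_t estimates created for 64 rows do not fit a batch of 36.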

One thing you could try is to:

  • use TraceEnum_ELBO to exactly marginalize out local discrete sample sites; and
  • ensure all continuous sample sites are global, e.g. a single shared df in your example.

Another thing you could try is to amortize inference instead of using an AutoDelta.

Isn’t df a param here and not a sample? I have moved it outside the plate now because I wanted this to be a global parameter. I was under the impression that it always loads from the parameter store (except when undefined).

If I understand correctly, you are suggesting a handwritten guide. Right?

Also, how is using an empty function guide different from using AutoDelta? I read on the forum that both lead to a MAP estimate. Is that understanding correct?

Yes, you are correct that df is a param and not a sample; I was mistaken.

you are suggesting a handwritten guide. Right?

Yes, a handwritten guide.

how is using an empty function guide different from using AutoDelta

An empty guide will only work when there are no latent variables (after possibly marginalizing out discrete latent variables using TraceEnum_ELBO). An AutoDelta guide provides point estimates of all latent variables. If you have no latent variables then AutoDelta is effectively an empty guide.
