Why I set the latent variable infer={'enumerate':...}, but still get ''found vars in model but no guide'' warning?


#1

Please find the warning, model (this is a simple HMM) and guide code below:

Warning:

UserWarning: Found vars in model but not guide: {'z_37', 'z_35', 'z_24', 'z_22', 'z_16', 'z_14', 'z_25', 'z_39', 'z_12', 'z_29', 'z_31', 'z_32', 'z_36', 'z_18', 'z_30', 'z_10', 'z_0', 'z_20', 'z_5', 'z_9', 'z_40', 'z_23', 'z_27', 'z_28', 'z_11', 'z_3', 'z_8', 'z_33', 'z_2', 'z_1', 'z_34', 'z_6', 'z_4', 'z_17', 'z_19', 'z_15', 'z_21', 'z_26', 'z_13', 'z_41', 'z_38', 'z_7'}
  warnings.warn("Found vars in model but not guide: {}".format(model_vars - guide_vars - enum_vars))

Model:

class DMM(nn.Module):
    def __init__(self,
                 nannotators=47,
                 nlabels=5,
                 vocabulary_size=13476,
                 use_cuda=False):
        super(DMM, self).__init__()
        self.use_cuda = use_cuda

        self.nannotators = nannotators
        self.nlabels = nlabels
        self.vocabulary_size = vocabulary_size

        if self.use_cuda:
            self.cuda()

    def model(self, mini_batch_tokens):
        pyro.module("dmm", self)

        T_max = mini_batch_tokens.size(1)
        ninstances = mini_batch_tokens.size(0)

        valid_mini_batch_tokens = mini_batch_tokens.clone()
        valid_mini_batch_tokens_mask = valid_mini_batch_tokens != -1
        valid_mini_batch_tokens[~valid_mini_batch_tokens_mask] = 0

        transition_0 = pyro.sample("transition_0", dist.Dirichlet(0.5 * torch.ones(self.nlabels)))
        with pyro.plate("labels", self.nlabels):
            transition = pyro.sample("transition", dist.Dirichlet(0.5 * torch.ones(self.nlabels)))
            token_emission = pyro.sample("token_emission", dist.Dirichlet(0.5 * torch.ones(self.vocabulary_size)))

        with pyro.plate("data", ninstances):
            for t in range(0, T_max):
                token_mask = (mini_batch_tokens[:, t] != -1)
                transition_param = transition[z_prev] if t > 0 else transition_0.expand(ninstances, self.nlabels)
                z_t = pyro.sample("z_{}".format(t),
                                  dist.Categorical(transition_param).mask(token_mask),
                                  infer={"enumerate":"sequential"})

                pyro.sample("obs_x_{}".format(t),
                            dist.Categorical(token_emission[z_t]).mask(token_mask),
                            obs=valid_mini_batch_tokens[:,t])

                z_prev = z_t

Guide:

elbo = TraceEnum_ELBO(max_plate_nesting=1)
guide = AutoDelta(poutine.block(dmm.model, expose=["transition_0",
                                                       "transition",
                                                       "token_emission"]))

loss_basic = SVI(dmm.model, guide, optim, loss=elbo)