Please find the warning, model (this is a simple HMM) and guide code below:
Warning:
UserWarning: Found vars in model but not guide: {'z_37', 'z_35', 'z_24', 'z_22', 'z_16', 'z_14', 'z_25', 'z_39', 'z_12', 'z_29', 'z_31', 'z_32', 'z_36', 'z_18', 'z_30', 'z_10', 'z_0', 'z_20', 'z_5', 'z_9', 'z_40', 'z_23', 'z_27', 'z_28', 'z_11', 'z_3', 'z_8', 'z_33', 'z_2', 'z_1', 'z_34', 'z_6', 'z_4', 'z_17', 'z_19', 'z_15', 'z_21', 'z_26', 'z_13', 'z_41', 'z_38', 'z_7'}
warnings.warn("Found vars in model but not guide: {}".format(model_vars - guide_vars - enum_vars))
Model:
class DMM(nn.Module):
def __init__(self,
nannotators=47,
nlabels=5,
vocabulary_size=13476,
use_cuda=False):
super(DMM, self).__init__()
self.use_cuda = use_cuda
self.nannotators = nannotators
self.nlabels = nlabels
self.vocabulary_size = vocabulary_size
if self.use_cuda:
self.cuda()
def model(self, mini_batch_tokens):
pyro.module("dmm", self)
T_max = mini_batch_tokens.size(1)
ninstances = mini_batch_tokens.size(0)
valid_mini_batch_tokens = mini_batch_tokens.clone()
valid_mini_batch_tokens_mask = valid_mini_batch_tokens != -1
valid_mini_batch_tokens[~valid_mini_batch_tokens_mask] = 0
transition_0 = pyro.sample("transition_0", dist.Dirichlet(0.5 * torch.ones(self.nlabels)))
with pyro.plate("labels", self.nlabels):
transition = pyro.sample("transition", dist.Dirichlet(0.5 * torch.ones(self.nlabels)))
token_emission = pyro.sample("token_emission", dist.Dirichlet(0.5 * torch.ones(self.vocabulary_size)))
with pyro.plate("data", ninstances):
for t in range(0, T_max):
token_mask = (mini_batch_tokens[:, t] != -1)
transition_param = transition[z_prev] if t > 0 else transition_0.expand(ninstances, self.nlabels)
z_t = pyro.sample("z_{}".format(t),
dist.Categorical(transition_param).mask(token_mask),
infer={"enumerate":"sequential"})
pyro.sample("obs_x_{}".format(t),
dist.Categorical(token_emission[z_t]).mask(token_mask),
obs=valid_mini_batch_tokens[:,t])
z_prev = z_t
Guide:
elbo = TraceEnum_ELBO(max_plate_nesting=1)
guide = AutoDelta(poutine.block(dmm.model, expose=["transition_0",
"transition",
"token_emission"]))
loss_basic = SVI(dmm.model, guide, optim, loss=elbo)