I am trying to allow variable length sequences in a language model. I think the right way to do this is using masks, but I currently get an error. My code looks like this:
class Seq2Hist(torch.nn.Module):
def __init__(self, D_in, D_out):
"""
In the constructor we instantiate two nn.Linear modules and assign them as
member variables.
"""
super(Seq2Hist, self).__init__()
self.num_words = D_out
def forward(self, x):
"""
In the forward function we accept a Tensor of input data and we must return
a Tensor of output data. We can use Modules defined in the constructor as
well as arbitrary operators on Tensors.
"""
out = torch.zeros(self.num_words + 1, x.shape[1]).scatter_add(
0, x, torch.ones(x.shape)
).transpose(1, 0)
return out[:, :-1]
class MLP(nn.Module):
def __init__(self, args, eps=1):
super(MLP, self).__init__()
self.eps = eps
self.seq2hist = Seq2Hist(args.num_words_per_doc, args.num_words)
self.prob_layer = nn.Linear(args.num_words, 2)
nn.init.xavier_uniform_(self.prob_layer.weight)
self.prob_scale_layer = nn.Linear(args.num_words, 1)
nn.init.xavier_uniform_(self.prob_scale_layer.weight)
self.anneal_floor = 0
# forward propagate input
def forward(self, X):
self.anneal_floor += 1
z = self.seq2hist(X)
background_prob = nn.Softmax(dim=-1)(self.prob_layer(z))
background_prob_scale = nn.ReLU()(self.prob_scale_layer(z)) + self.eps / self.anneal_floor
background_prob_prior = background_prob * background_prob_scale
return background_prob_prior
def model(data=None, mask=None, args=None, batch_size=None, annealing_factor=1):
with poutine.scale(None, annealing_factor):
# Globals.
topic_weights = pyro.sample(
"topic_weights", dist.Dirichlet(10 * torch.ones(args.num_topics))
)
with pyro.plate("topics", args.num_topics + 1):
topic_words = pyro.sample(
"topic_words", dist.Dirichlet(torch.ones(args.num_words) / args.num_words),
)
# Locals.
with pyro.plate("documents", data.size(1) if data is not None else args.num_docs):
with poutine.scale(None, annealing_factor):
doc_background_prob = pyro.sample("doc_background_prob",
dist.Beta(10 * torch.ones(1), 10 * torch.ones(1))
).unsqueeze(-1)
doc_topic = pyro.sample("doc_topic", dist.Categorical(topic_weights),
infer={"enumerate": "parallel"}) + 1
with pyro.plate("words", args.num_words_per_doc):
# The word_topics variable is marginalized out during inference,
# achieved by specifying infer={"enumerate": "parallel"} and using
# TraceEnum_ELBO for inference. Thus we can ignore this variable in
# the guide.
with poutine.mask(None, mask):
word_background_ind = pyro.sample(
"word_background_ind",
dist.Bernoulli(doc_background_prob.squeeze()),
infer={"enumerate": "parallel"}
)
topic_ind = (doc_topic * word_background_ind).type(torch.LongTensor)
data = pyro.sample(
"doc_words", dist.Categorical(Vindex(topic_words)[topic_ind]),
obs=data, obs_mask=mask
)
return topic_weights, topic_words, doc_topic, data
def parametrized_guide(predictor, data, mask, args, batch_size=None, annealing_factor=1):
# Use a conjugate guide for global variables.
topic_weights_posterior = pyro.param(
"topic_weights_posterior",
lambda: torch.ones(args.num_topics),
constraint=constraints.positive,
)
topic_words_posterior = pyro.param(
"topic_words_posterior",
lambda: torch.ones(args.num_topics + 1, args.num_words),
constraint=constraints.greater_than(1e-8),
)
with poutine.scale(None, annealing_factor):
pyro.sample("topic_weights", dist.Dirichlet(topic_weights_posterior))
with pyro.plate("topics", args.num_topics + 1):
pyro.sample("topic_words", dist.Dirichlet(topic_words_posterior))
# Use an amortized guide for continuous local variables.
pyro.module("predictor", predictor)
with pyro.plate("documents", data.size(1)):
background_prob_posterior = predictor(data)
with poutine.scale(None, annealing_factor):
pyro.sample("doc_background_prob",
dist.Beta(background_prob_posterior[:, 0],
background_prob_posterior[:, 1]))
The sequences are padded were originally padded with a -1 but this gave an out of support error. So I switched to padding the sequences with a 0, but now the topics all put high probability on 0, so clearly I’m not masking correctly here, or maybe masking isn’t the right approach…
Any assistance would be helpful!