Right, it's because my actual model is a bit more complicated:
class Seq2Hist(torch.nn.Module):
    """Turn columns of word ids into bag-of-words count histograms.

    The module has no learnable parameters; it only counts how often each
    vocabulary id occurs in every column of its input.
    """

    def __init__(self, D_in, D_out):
        """Store the vocabulary size.

        Args:
            D_in: sequence length (words per document). Not used by the
                computation; kept for interface compatibility.
            D_out: vocabulary size, i.e. the number of histogram bins.
        """
        super().__init__()
        self.num_words = D_out

    def forward(self, x):
        """Count word occurrences per column of ``x``.

        Args:
            x: int64 tensor of word ids laid out as (seq_len, num_docs);
                every entry must lie in ``[0, num_words)``.

        Returns:
            Float tensor of shape (num_docs, num_words) whose entry [d, w] is
            the number of times id ``w`` occurs in column ``d`` of ``x``.
        """
        # Allocate on x's device: bare torch.zeros/torch.ones would live on the
        # CPU and scatter_add would fail for CUDA inputs.
        counts = torch.zeros(self.num_words, x.shape[1], device=x.device)
        ones = torch.ones(x.shape, device=x.device)
        # Scatter along dim 0: counts[x[i, j], j] += 1 for every position.
        counts = counts.scatter_add(0, x, ones)
        return counts.transpose(0, 1)
def model(data=None, args=None, batch_size=None, annealing_factor=1):
    """Generative topic model with a shared background topic.

    Row 0 of ``topic_words`` is the background word distribution; rows
    ``1..num_topics`` are the regular topics. Each document draws one topic
    ``doc_topic`` in ``1..num_topics``; each word then draws a Bernoulli
    indicator that selects the document's topic (indicator 1) or the
    background row 0 (indicator 0).

    Args:
        data: observed word ids or None to sample them. ``data.size(1)`` is
            the number of documents. NOTE(review): presumably laid out as
            (num_words_per_doc, num_docs) to match the nested plates — confirm.
        args: namespace providing ``num_topics``, ``num_words``,
            ``num_words_per_doc`` and, when ``data`` is None, ``num_docs``.
        batch_size: unused; kept for interface compatibility.
        annealing_factor: multiplier on the log-density of the global latents
            and of the per-document latents (KL annealing).

    Returns:
        Tuple ``(topic_weights, topic_words, doc_topic, data)``.
    """
    # Globals.
    with poutine.scale(None, annealing_factor):
        topic_weights = pyro.sample(
            "topic_weights", dist.Dirichlet(10 * torch.ones(args.num_topics))
        )
        # num_topics + 1 rows: row 0 is the background distribution.
        with pyro.plate("topics", args.num_topics + 1):
            topic_words = pyro.sample(
                "topic_words",
                dist.Dirichlet(torch.ones(args.num_words) / args.num_words),
            )

    # Locals.
    num_docs = data.size(1) if data is not None else args.num_docs
    with pyro.plate("documents", num_docs):
        with poutine.scale(None, annealing_factor):
            # NOTE(review): indicator 1 picks the document topic, so this is
            # really the probability of *not* using the background topic —
            # the name looks inverted; confirm.
            doc_background_prob = pyro.sample(
                "doc_background_prob",
                dist.Beta(10 * torch.ones(1), 10 * torch.ones(1)),
            )
            # Shift by one so regular topics never collide with background row 0.
            doc_topic = pyro.sample("doc_topic", dist.Categorical(topic_weights)) + 1
        # NOTE(review): the words plate is placed outside the annealing scale,
        # so the word-level indicator and observed words are not annealed —
        # confirm this matches the intended nesting.
        with pyro.plate("words", args.num_words_per_doc):
            # word_background_ind is marginalized out during inference
            # (infer={"enumerate": "parallel"} + TraceEnum_ELBO), so the guide
            # does not need to sample it.
            word_background_ind = pyro.sample(
                "word_background_ind",
                dist.Bernoulli(doc_background_prob),
                infer={"enumerate": "parallel"},
            )
            # 0 -> background row, otherwise the document's topic row.
            # .long() keeps the tensor's device; .type(torch.LongTensor) would
            # force a CPU tensor and break indexing on GPU.
            topic_ind = (doc_topic * word_background_ind).long()
            data = pyro.sample(
                "doc_words", dist.Categorical(Vindex(topic_words)[topic_ind]), obs=data
            )
    return topic_weights, topic_words, doc_topic, data
class MLP(nn.Module):
    """Amortized encoder producing the guide's per-document parameters.

    Word-id input is converted to per-document count histograms by
    ``Seq2Hist``. The histograms optionally pass through ReLU hidden layers and
    are then fed to three linear heads:
    - a two-way softmax scaled by a non-negative factor, giving Beta
      concentrations;
    - a softmax over ``num_topics``, giving topic probabilities.

    Args:
        args: namespace providing ``num_words``, ``num_words_per_doc`` and
            ``num_topics``.
        eps: numerator of the decaying floor ``eps / n_forward_calls`` added
            to the Beta concentration scale.
        hidden_sizes: optional sequence of hidden-layer widths. The default
            (None) builds no hidden layers, matching the previous behaviour.
    """

    def __init__(self, args, eps=1, hidden_sizes=None):
        super().__init__()
        layer_sizes = [args.num_words] + list(hidden_sizes or ())
        logging.info("Creating MLP with sizes %s", layer_sizes)
        self.eps = eps
        self.seq2hist = Seq2Hist(args.num_words_per_doc, args.num_words)
        # nn.ModuleList (not a plain list) so hidden-layer parameters are
        # registered on this module and picked up by optimizers / pyro.module.
        self.layers = nn.ModuleList()
        for in_size, out_size in zip(layer_sizes, layer_sizes[1:]):
            layer = nn.Linear(in_size, out_size)
            nn.init.xavier_uniform_(layer.weight)
            self.layers.append(layer)
        in_features = layer_sizes[-1]
        self.prob_layer = nn.Linear(in_features, 2)
        nn.init.xavier_uniform_(self.prob_layer.weight)
        self.prob_scale_layer = nn.Linear(in_features, 1)
        nn.init.xavier_uniform_(self.prob_scale_layer.weight)
        self.topic_layer = nn.Linear(in_features, args.num_topics)
        nn.init.xavier_uniform_(self.topic_layer.weight)
        self.topic_act = nn.Softmax(dim=-1)
        # Number of forward() calls so far; eps / anneal_floor decays with it.
        self.anneal_floor = 0

    def forward(self, X):
        """Map a batch of word ids to ``(background_prob_prior, topic_dist)``.

        Side effect: increments ``self.anneal_floor`` on every call (including
        evaluation calls), which shrinks the ``eps`` floor.

        Args:
            X: word-id tensor accepted by ``Seq2Hist`` — presumably
                (num_words_per_doc, num_docs); TODO confirm.

        Returns:
            background_prob_prior: (num_docs, 2) Beta concentrations, strictly
                positive whenever ``eps > 0``.
            topic_dist: (num_docs, num_topics) topic probabilities.
        """
        self.anneal_floor += 1
        z = self.seq2hist(X)
        for layer in self.layers:
            z = torch.relu(layer(z))
        background_prob = torch.softmax(self.prob_layer(z), dim=-1)
        # ReLU keeps the scale >= 0; the decaying eps term keeps it > 0 so the
        # resulting Beta concentrations are valid.
        background_prob_scale = (
            torch.relu(self.prob_scale_layer(z)) + self.eps / self.anneal_floor
        )
        background_prob_prior = background_prob * background_prob_scale
        topic_dist = self.topic_act(self.topic_layer(z))
        return background_prob_prior, topic_dist
def parametrized_guide(predictor, data, args, batch_size=None, annealing_factor=1):
    """Variational guide paired with ``model``.

    The global latents use Dirichlet posteriors whose concentrations are
    learnable ``pyro.param``s. The per-document latents are amortized through
    ``predictor``. ``doc_topic`` is enumerated in parallel via its own
    ``infer`` option rather than a decorator.

    Args:
        predictor: module called as ``predictor(data)`` that returns
            ``(beta_concentrations, topic_probs)``; registered with
            ``pyro.module`` so its parameters are trained.
        data: observed word ids; ``data.size(1)`` is the number of documents.
        args: namespace providing ``num_topics`` and ``num_words``.
        batch_size: unused.
        annealing_factor: log-density multiplier, mirroring the model's.
    """
    n_topics = args.num_topics

    # Conjugate guide for the global variables.
    weights_conc = pyro.param(
        "topic_weights_posterior",
        lambda: torch.ones(n_topics),
        constraint=constraints.positive,
    )
    words_conc = pyro.param(
        "topic_words_posterior",
        lambda: torch.ones(n_topics + 1, args.num_words),
        constraint=constraints.greater_than(1e-8),
    )
    with poutine.scale(None, annealing_factor):
        pyro.sample("topic_weights", dist.Dirichlet(weights_conc))
        with pyro.plate("topics", n_topics + 1):
            pyro.sample("topic_words", dist.Dirichlet(words_conc))

    # Amortized guide for the per-document variables.
    pyro.module("predictor", predictor)
    with pyro.plate("documents", data.size(1)):
        beta_conc, topic_probs = predictor(data)
        conc_a = beta_conc[:, 0]
        conc_b = beta_conc[:, 1]
        with poutine.scale(None, annealing_factor):
            pyro.sample("doc_background_prob", dist.Beta(conc_a, conc_b))
            pyro.sample(
                "doc_topic",
                dist.Categorical(topic_probs),
                infer={"enumerate": "parallel"},
            )