Hi, I’m new to pyro so many thanks for comments.
This is an attempt to understand effect handlers via a simple modification of the amortized lda tutorial to ignore part of the data tensor.
Suppose each of the 1000 docs of length 64 is actually padded and we have:
doclengths = torch.randint(low=20,high=40,size=(1000,1))
Then update the model so that poutine.mask
ignores doclengths[ind]+
for each doc:
def model(data=None, doclengths=doclengths,args=None, batch_size=None):
# Globals.
with pyro.plate("topics", args.num_topics):
topic_weights = pyro.sample("topic_weights", dist.Gamma(1. / args.num_topics, 1.))
topic_words = pyro.sample("topic_words",
dist.Dirichlet(torch.ones(args.num_words) / args.num_words))
# Locals.
with pyro.plate("documents", args.num_docs) as ind:
if data is not None:
with pyro.util.ignore_jit_warnings():
assert data.shape == (args.num_words_per_doc, args.num_docs)
data = data[:, ind]
doc_topics = pyro.sample("doc_topics", dist.Dirichlet(topic_weights))
x = ind.view(ind.shape[0],1)
y = torch.arange(args.num_words_per_doc).unsqueeze(-2)
_,b=torch.broadcast_tensors(x,y)
with pyro.plate("words", args.num_words_per_doc):
# The word_topics variable is marginalized out during inference,
# achieved by specifying infer={"enumerate": "parallel"} and using
# TraceEnum_ELBO for inference. Thus we can ignore this variable in
# the guide.
with poutine.mask(mask=(b < doclengths[ind]).unsqueeze(-1)):
# with poutine.mask(mask=m < doclengths):
word_topics = pyro.sample("word_topics", dist.Categorical(doc_topics),
infer={"enumerate": "parallel"})
data = pyro.sample("doc_words", dist.Categorical(topic_words[word_topics]),
obs=data)
else:
doc_topics = pyro.sample("doc_topics", dist.Dirichlet(topic_weights))
with pyro.plate("words", args.num_words_per_doc):
# The word_topics variable is marginalized out during inference,
# achieved by specifying infer={"enumerate": "parallel"} and using
# TraceEnum_ELBO for inference. Thus we can ignore this variable in
# the guide.
word_topics = pyro.sample("word_topics", dist.Categorical(doc_topics),
infer={"enumerate": "parallel"})
data = pyro.sample("doc_words", dist.Categorical(topic_words[word_topics]),obs=data)
return topic_weights, topic_words, data
This throws an error:
Exception has occurred: RuntimeError
The size of tensor a (32) must match the size of tensor b (8) at non-singleton dimension 0
File "/home/au/code/pyro_examples/lda_amortized_ragged/__main__.py", line 149, in main
loss = svi.step(data, args=args, batch_size=args.batch_size)
File "/home/au/code/pyro_examples/lda_amortized_ragged/__main__.py", line 169, in <module>
main(args)
Regarding:
loss = svi.step(data, args=args, batch_size=args.batch_size)