I am trying to implement a simple LDA model using dummy data. I have two problems:
- I cannot seem to get the example below to run
- I do not know whether I need to use
.independent()
within my plates or not.
Any help would be appreciated!
Setup and generate dataset:
import torch
import torch.distributions.constraints as constraints
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, JitTraceEnum_ELBO, TraceEnum_ELBO, config_enumerate
from pyro.optim import Adam
num_docs = 7
num_topics = 3
num_words = 5
# hyperparams
alpha = torch.zeros([num_docs, num_topics]) + 0.5
gamma = torch.zeros([num_topics, num_words]) + 0.01
# priors
theta = dist.Dirichlet(alpha).sample()
beta = dist.Dirichlet(gamma).sample()
# topics
z = [torch.zeros(num_words, dtype = torch.long) for i in range(num_docs)]
# word counts
data = [torch.zeros(num_words) for i in range(num_docs)]
for d in range(num_docs):
for n in range(num_words):
z[d][n] = dist.Categorical(theta[d, :]).sample()
data[d][n] = dist.Categorical(beta[z[d][n], :]).sample()
Model:
# define the model (i.e. joint distribution p(beta|gamma)p(theta|alpha)p(z|theta)p(w|z, beta))
def model(data):
# setup hyperparameters for the priors p(beta) p(theta)
alpha = pyro.sample('alpha', dist.Gamma(torch.ones(num_topics) / num_topics, 1.0))
gamma = torch.ones(num_words) / num_words
# loop over topics
with pyro.plate('topics', num_topics) as k:
# sample from prior p(gamma)
beta = pyro.sample('beta_{0}'.format(k), dist.Dirichlet(gamma))
# loop over documents
with pyro.plate('documents', num_docs) as d:
# sample from prior p(alpha)
theta = pyro.sample('theta_{0}'.format(d), dist.Dirichlet(alpha))
# loop over words in a given document
with pyro.plate('words', num_words) as n:
# sample from likelihood p(theta|alpha)
z = pyro.sample('z_{0}_ {1}'.format(d, n), dist.Categorical(theta))
# sample from likelihood p(w|z, beta) and score against actual observed words
w = pyro.sample('w_{0}_{1}'.format(d, n), dist.Categorical(beta[z]), obs = data[d])
Guide:
# define the guide (i.e. variational distribution q(theta|eta)q(z|phi)q(beta|lambda)
def guide(data):
eta = pyro.param('eta', (torch.ones(num_topics) + 0.5) / num_topics, constraint = constraints.positive)
lamda = pyro.param('lambda', (torch.ones(num_words) + 0.25) / num_words, constraint = constraints.positive)
with pyro.plate('topics', num_topics) as k:
beta_q = pyro.sample('beta_{0}'.format(k), dist.Dirichlet(lamda))
with pyro.plate('documents', num_docs) as d:
theta_q = pyro.sample('theta_{0}'.format(d), dist.Dirichlet(eta))
with pyro.plate('words', num_words) as n:
phi = pyro.param('phi_{0}'.format(n), torch.randn([num_docs, num_topics]).exp(), constraint = constraints.simplex)
z_q = pyro.sample('z_{0}_{1}'.format(d, n), dist.Categorical(phi))
Variational inference:
pyro.clear_param_store()
adam_params = {'lr': 0.01}
optimizer = Adam(adam_params)
svi = SVI(model, config_enumerate(guide, 'parallel'), optimiser, loss = TraceEnum_ELBO(max_iarange_nesting = 2))
for _ in range(10):
loss = svi.step(data)
Error:
<ipython-input-596-8f2d81fbf6cf> in model(data)
11 with pyro.plate('words', num_words) as n:
12 z = pyro.sample('z_{0}_{1}'.format(d, n), dist.Categorical(theta))
---> 13 w = pyro.sample('w_{0}_{1}'.format(d, n), dist.Categorical(beta[z]), obs = data[d])
TypeError: only integer tensors of a single element can be converted to an index