Simple LDA not working

I am trying to implement a simple LDA model using dummy data. I have two problems:

  1. I cannot seem to get the example below to run
  2. I do not know whether I need to call `.independent()` (renamed `.to_event()` in newer Pyro versions) on the distributions inside my plates.

Any help would be appreciated!


Set up and generate the dataset:

import torch
import torch.distributions.constraints as constraints

import pyro
import pyro.distributions as dist
from pyro.infer import SVI, JitTraceEnum_ELBO, TraceEnum_ELBO, config_enumerate
from pyro.optim import Adam

# Sizes of the synthetic corpus. num_words is used both as the number of
# word positions per document (inner loop below) and as the vocabulary
# size (last dimension of gamma and beta).
num_docs = 7
num_topics = 3
num_words = 5

# hyperparams: Dirichlet concentrations, one row per document (alpha)
# and one row per topic (gamma)
alpha = torch.zeros([num_docs, num_topics]) + 0.5
gamma = torch.zeros([num_topics, num_words]) + 0.01

# priors: theta[d] is document d's topic mixture (num_docs x num_topics),
# beta[k] is topic k's word distribution (num_topics x num_words)
theta = dist.Dirichlet(alpha).sample()
beta = dist.Dirichlet(gamma).sample()

# topics: z[d][n] is the topic drawn for word position n of document d
z = [torch.zeros(num_words, dtype = torch.long) for i in range(num_docs)]
# observed words: data[d][n] is the word index (0..num_words-1) drawn at
# position n of document d, stored in a float tensor -- not a word count.
# NOTE(review): data stays a Python list of tensors here. Indexing a list
# with the multi-element index tensor of pyro.plate (data[d] in the model)
# presumably raises the "only integer tensors of a single element can be
# converted to an index" TypeError reported in this thread;
# torch.stack(data) gives a (num_docs, num_words) tensor instead.
data = [torch.zeros(num_words) for i in range(num_docs)]
for d in range(num_docs):
    for n in range(num_words):
        z[d][n] = dist.Categorical(theta[d, :]).sample()
        data[d][n] = dist.Categorical(beta[z[d][n], :]).sample()

Model:

# define the model (i.e. joint distribution p(beta|gamma)p(theta|alpha)p(z|theta)p(w|z, beta))
def model(data):
    """LDA generative model, first attempt (raises the TypeError shown below).

    Joint p(beta|gamma) p(theta|alpha) p(z|theta) p(w|z, beta), with a Gamma
    hyperprior on alpha. ``data`` is indexed as ``data[d]`` using the index
    tensor of the 'documents' plate.

    NOTE(review): problems visible in this version --
      * Site names are built with str.format on the plate index tensors
        (k, d, n), so each name embeds a tensor's printed form; the
        corrected model later in the thread uses fixed names instead.
      * The 'z' name template has a stray space ('z_{0}_ {1}'), so it can
        never match the guide's 'z_{0}_{1}'.
      * 'alpha' has batch shape (num_topics,) but is sampled outside any
        plate and without .to_event(1); the guide has no 'alpha' site.
      * If ``data`` is the Python list built above, ``data[d]`` fails
        because ``d`` is a multi-element tensor.
    """
    # setup hyperparameters for the priors p(beta) p(theta)
    alpha =  pyro.sample('alpha', dist.Gamma(torch.ones(num_topics) / num_topics, 1.0))
    gamma = torch.ones(num_words) / num_words
    # plate over topics (vectorized; k is the tensor of topic indices)
    with pyro.plate('topics', num_topics) as k:
        # topic-word distributions: beta ~ p(beta|gamma) = Dirichlet(gamma)
        beta = pyro.sample('beta_{0}'.format(k), dist.Dirichlet(gamma))
    # plate over documents (d is the tensor of document indices)
    with pyro.plate('documents', num_docs) as d:
        # document-topic mixtures: theta ~ p(theta|alpha) = Dirichlet(alpha)
        theta = pyro.sample('theta_{0}'.format(d), dist.Dirichlet(alpha))
        # nested plate over word positions within each document
        with pyro.plate('words', num_words) as n:
            # topic assignment of each word: z ~ p(z|theta)
            z = pyro.sample('z_{0}_ {1}'.format(d, n), dist.Categorical(theta))
            # sample from likelihood p(w|z, beta) and score against actual observed words
            w = pyro.sample('w_{0}_{1}'.format(d, n), dist.Categorical(beta[z]), obs = data[d])

Guide:

# define the guide (i.e. variational distribution q(theta|eta)q(z|phi)q(beta|lambda)
def guide(data):
    """Variational guide for the first model attempt.

    q(beta|lambda) q(theta|eta) q(z|phi): Dirichlet concentrations 'eta'
    (num_topics,) and 'lambda' (num_words,) constrained positive, plus a
    simplex-constrained 'phi_...' param of shape (num_docs, num_topics) for
    the topic-assignment probabilities.

    NOTE(review):
      * 'eta' and 'lambda' are single vectors, so one q(theta) is shared by
        all documents and one q(beta) by all topics.
      * Site/param names are formatted from plate index tensors; the 'z'
        name here ('z_{0}_{1}') differs from the model's ('z_{0}_ {1}').
      * There is no 'alpha' site although the model samples one.
    """
    eta = pyro.param('eta', (torch.ones(num_topics) + 0.5) / num_topics, constraint = constraints.positive)
    # 'lambda' is a Python keyword, hence the variable name 'lamda'
    lamda = pyro.param('lambda', (torch.ones(num_words) + 0.25) / num_words, constraint = constraints.positive)
    
    with pyro.plate('topics', num_topics) as k:
        beta_q = pyro.sample('beta_{0}'.format(k), dist.Dirichlet(lamda))
  
    with pyro.plate('documents', num_docs) as d:
        theta_q = pyro.sample('theta_{0}'.format(d), dist.Dirichlet(eta))
        
        with pyro.plate('words', num_words) as n:
            # phi is declared inside the words plate, so its name depends on n
            phi = pyro.param('phi_{0}'.format(n), torch.randn([num_docs, num_topics]).exp(), constraint = constraints.simplex)
       
            z_q = pyro.sample('z_{0}_{1}'.format(d, n), dist.Categorical(phi))

Variational inference:

# Start from a clean parameter store so params from earlier runs are not reused.
pyro.clear_param_store()

adam_params = {'lr': 0.01}
optimizer = Adam(adam_params)
# Enumerate the guide's discrete 'z' site in parallel. max_plate_nesting=2
# covers the nested 'documents' and 'words' plates; it is the current name
# of the deprecated max_iarange_nesting argument.
svi = SVI(model, config_enumerate(guide, 'parallel'), optimizer,
          loss=TraceEnum_ELBO(max_plate_nesting=2))

# Run a few SVI steps; svi.step returns the loss estimate for that step.
for _ in range(10):
    loss = svi.step(data)

Error:

<ipython-input-596-8f2d81fbf6cf> in model(data)
     11         with pyro.plate('words', num_words) as n:
     12             z = pyro.sample('z_{0}_{1}'.format(d, n), dist.Categorical(theta))
---> 13             w = pyro.sample('w_{0}_{1}'.format(d, n), dist.Categorical(beta[z]), obs = data[d])

TypeError: only integer tensors of a single element can be converted to an index

If I make corrections to the setup and dataset:

# Synthetic corpus: for document d and word position n, draw a topic
# z[d][n] from theta[d], then a word index from that topic's row of beta.
z = [torch.zeros(num_words, dtype=torch.long) for _ in range(num_docs)]
data = [torch.zeros(num_words) for _ in range(num_docs)]
for d in range(num_docs):
    doc_topics, doc_words = z[d], data[d]
    topic_dist = dist.Categorical(theta[d, :])
    for n in range(num_words):
        doc_topics[n] = topic_dist.sample()
        doc_words[n] = dist.Categorical(beta[doc_topics[n], :]).sample()
# Stack into a single (num_docs, num_words) tensor of word indices.
data = torch.stack(data)

Corrections to model:

def model(data):
    """LDA generative model with fixed site names (second attempt).

    Samples 'beta' (num_topics x num_words), 'theta' (num_docs x
    num_topics), topic assignments 'z' and observed words 'w', using
    symmetric Dirichlet priors. ``data`` is the stacked
    (num_docs, num_words) tensor of word indices.

    NOTE(review): this raises the ValueError at site 'w' quoted in this
    thread. No plate is given an explicit ``dim``, so 'documents' is
    allocated dim -1 and the nested 'words' plate dim -2 (trace: z dist
    5 7). Together with the leading enumeration dim of size num_topics
    from the guide's parallel-enumerated 'z' (trace: z value 3 1 1), the
    'w' distribution has batch shape (3, 5, 7), which the (7, 5)
    observation data[d, :] does not broadcast against. Fixes: observe
    data.t(), or pass dim=-2 to 'documents' and dim=-1 to 'words'.
    """
    gamma = torch.ones(num_words) / num_words
    with pyro.plate('topics', num_topics) as k:
        beta = pyro.sample('beta', dist.Dirichlet(gamma))
        assert beta.shape == (num_topics, num_words)

    alpha = torch.ones(num_topics) / num_topics
    with pyro.plate('documents', num_docs) as d:
        theta = pyro.sample('theta',dist.Dirichlet(alpha))
        assert theta.shape == (num_docs, num_topics)
        
        with pyro.plate('words', num_words) as n:
            z = pyro.sample('z', dist.Categorical(theta))
            # Would fail: under guide-side enumeration z has shape (3, 1, 1).
            #assert z.shape == (num_words, num_docs)
            # beta[z] selects each assigned topic's word distribution.
            w = pyro.sample('w', dist.Categorical(beta[z]), obs = data[d, :])
            assert w.shape == (num_docs, num_words)

Corrections to guide:

def guide(data):
    """Variational guide for the second model attempt (sites 'beta', 'theta', 'z').

    Params: 'gamma_q' (num_words,) and 'alpha_q' (num_topics,), both
    positive, as Dirichlet concentrations for q(beta) and q(theta); 'phi'
    (num_docs, num_topics) on the simplex as q(z) probabilities.

    NOTE(review): the params are created inside the plates but are not
    plate-shaped, so the plates broadcast them. Every topic shares one
    q(beta), every document shares one q(theta), and every word position
    of document d shares the same q(z) row phi[d]. Independent variational
    factors would presumably need shapes (num_topics, num_words),
    (num_docs, num_topics) and one row per (word, document) for phi --
    TODO confirm against the chosen plate dims.

    Under config_enumerate(guide, 'parallel'), 'z' is enumerated rather
    than sampled, so its value has shape (3, 1, 1) (see the trace in this
    thread) instead of (num_words, num_docs).
    """
    with pyro.plate('topics', num_topics) as k:
        gamma_q = pyro.param('gamma_q', torch.ones(num_words), constraint = constraints.positive)
        beta_q = pyro.sample('beta', dist.Dirichlet(gamma_q))
        assert beta_q.shape == (num_topics, num_words)

    with pyro.plate('documents', num_docs) as d:
        alpha_q = pyro.param('alpha_q', torch.ones(num_topics), constraint =constraints.positive)
        theta_q = pyro.sample('theta', dist.Dirichlet(alpha_q))
        assert theta_q.shape == (num_docs, num_topics)

        with pyro.plate('words', num_words) as n:
            phi = pyro.param('phi', torch.randn([num_docs, num_topics]).exp(), constraint = constraints.simplex)
            z_q = pyro.sample('z', dist.Categorical(phi)) 
            # Would fail under parallel enumeration (value shape (3, 1, 1)).
            #assert z_q.shape == (num_words, num_docs) 

Corrections to variational inference:

# Turn on Pyro's distribution validation so shape/support problems surface
# as errors at the offending sample site.
pyro.enable_validation(True)
# Start from a clean parameter store so params from earlier runs are not reused.
pyro.clear_param_store()

adam_params = {'lr': 0.01}
optimizer = Adam(adam_params)
# Enumerate the guide's discrete 'z' site in parallel. max_plate_nesting=2
# covers the nested 'documents' and 'words' plates; it is the current name
# of the deprecated max_iarange_nesting argument.
svi = SVI(model, config_enumerate(guide, 'parallel'), optimizer,
          loss=TraceEnum_ELBO(max_plate_nesting=2))

# Run a few SVI steps; svi.step returns the loss estimate for that step.
for _ in range(10):
    loss = svi.step(data)

Then I get the following error

ValueError: Error while computing log_prob at site 'w':
Value is not broadcastable with batch_shape+event_shape: torch.Size([7, 5]) vs torch.Size([3, 5, 7]).
Trace Shapes:          
 Param Sites:          
Sample Sites:          
    beta dist     3 | 5
        value     3 | 5
     log_prob     3 |  
   theta dist     7 | 3
        value     7 | 3
     log_prob     7 |  
       z dist   5 7 |  
        value 3 1 1 |  
     log_prob 3 5 7 |  
       w dist 3 5 7 |  
        value   7 5 |  

I have checked the shapes of my tensors, but cannot see where the error is occurring:

model
----------
alpha [3]
gamma [5]
beta  [3, 5]
theta [7, 3]
z     [5, 7]
w     [7, 5]

guide
----------
alpha  [3]
gamma  [5]
beta   [3, 5]
theta  [7, 3]
phi    [7, 3]
z      [5, 7]

Could it have something to do with `assert z.shape == (num_words, num_docs)` being false? I assumed that z.shape = (num_words, num_docs) = (5, 7), but instead it is (3, 1, 1).

I was told that every non-scalar sample must either be inside a plate or have `.to_event(…)` applied to its distribution.

@erlebach I have moved all of my sample statements inside plates, but that still has not fixed the error.

I get the following error from the above corrected code:

ValueError: Error while computing log_prob at site 'w':
Value is not broadcastable with batch_shape+event_shape: torch.Size([7, 5]) vs torch.Size([3, 5, 7]).
Trace Shapes:          
 Param Sites:          
Sample Sites:          
    beta dist     3 | 5
        value     3 | 5
     log_prob     3 |  
   theta dist     7 | 3
        value     7 | 3
     log_prob     7 |  
       z dist   5 7 |  
        value 3 1 1 |  
     log_prob 3 5 7 |  
       w dist 3 5 7 |  
        value   7 5 |  

Can you please email me the full code at gordon.erlebach@gmail.com? I will take a look; I still have not gotten any LDA model to work with SVI, so I would like to try yours. What data did you use?
Thanks.

The error says that the shape at the `w` site is wrong. I guess (given my limited knowledge of LDA) that the `w = pyro.sample('w', dist.Categorical(beta[z]), obs = data[d, :])` line in your model

should be

# here your code assumed that vocab_size = num_words
# (this snippet replaces the 'w' sample statement inside the 'words' plate)
word_distribution = beta[z]  # shape: num_words x num_docs x vocab_size (plus leading enumeration dims when z is enumerated)
data = data.t()  # shape: num_words x num_docs, which contains numbers from 0 to vocab_size-1
w = pyro.sample('w', dist.Categorical(word_distribution), obs=data)
assert w.shape == (num_words, num_docs)
1 Like

If you don’t want to transpose `data` to match the correct batch_shape, you can instead specify `dim` for your plate statements. For example,

    # Pin 'documents' to dim -2 and 'words' to dim -1 so batch shapes are
    # (num_docs, num_words), matching the layout of the stacked `data`.
    # NOTE(review): the guide's plates presumably need the same dim=
    # arguments so that its 'theta'/'z' shapes line up -- confirm.
    with pyro.plate('documents', num_docs, dim=-2):
        # theta carries a singleton dim where the 'words' plate (dim -1) sits
        theta = pyro.sample('theta',dist.Dirichlet(alpha))
        assert theta.shape == (num_docs, 1, num_topics)
        
        with pyro.plate('words', num_words, dim=-1):
            z = pyro.sample('z', dist.Categorical(theta))
            # Would not hold if z is enumerated (extra leading enumeration dims).
            #assert z.shape == (num_docs, num_words)
            w = pyro.sample('w', dist.Categorical(beta[z]), obs=data)
            assert w.shape == (num_docs, num_words)
1 Like