Chinese restaurant process model

Hi all, I want to build a model of the Chinese restaurant process (CRP) based on this blog, which is related to the Dirichlet process. According to that blog, the process works as follows:

  1. Imagine a restaurant where all your friends went to eat yesterday…
  2. Initially, the restaurant is empty.
  3. The first person to enter (Alice) sits down at a table (selects a group). She then orders food for the table (i.e., she selects parameters for the group); everyone else who joins the table will then be limited to eating from the food she ordered.
  4. The second person to enter (Bob) sits down at a table. Which table does he sit at? With probability α/(1+α) he sits down at a new table (i.e., selects a new group) and orders food for the table; with probability 1/(1+α) he sits with Alice and eats from the food she’s already ordered (i.e., he’s in the same group as Alice).
  5. The (n+1)-st person sits down at a new table with probability α/(n+α), and at table k with probability n_k/(n+α), where n_k is the number of people currently sitting at table k (see the short sketch after this list).
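
For concreteness, here is a minimal sketch of rule 5 for a single customer (plain PyTorch; the table counts and alpha below are made-up example values, not taken from the blog):

import torch

counts = torch.tensor([4., 2., 1.])   # n_k: customers already sitting at tables 0, 1, 2
alpha = 1.0
n = counts.sum()                      # 7 customers seated so far

p_existing = counts / (n + alpha)     # probability of joining each existing table: n_k/(n+alpha)
p_new = alpha / (n + alpha)           # probability of opening a new table: alpha/(n+alpha)

print(p_existing, p_new)              # tensor([0.5000, 0.2500, 0.1250]) tensor(0.1250)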

The result of this process is an assignment of customers to tables:

chinese_restaurant_process(num_customers = 10, alpha = 1)
first table is table 1
1, 2, 3, 4, 3, 3, 2, 1, 4, 3 # table assignments from run 1
1, 1, 1, 1, 1, 1, 2, 2, 1, 3 # table assignments from run 2
1, 2, 2, 1, 3, 3, 2, 1, 3, 4 # table assignments from run 3

Here is my code:

import os
import numpy as np
import scipy.stats
import torch
import seaborn as sns
from torch.distributions import constraints
import matplotlib.pyplot as plt
import time

%matplotlib inline

import pyro
import pyro.distributions as dist
from pyro.optim import Adam
from pyro.infer import SVI, TraceEnum_ELBO, Trace_ELBO


smoke_test = ('CI' in os.environ)
pyro.enable_validation(True)

def create_dataset_crp(alpha=3,n=10):
    num_customers = [] # number of customers at each table
    table_assignment =[]
    for i in range(n):
        if i == 0:
            z_i = 0 # first customer always sits at table 0
            table_assignment.append(z_i)
        else:
            probs = torch.tensor([c/(i+alpha) for c in (num_customers + [alpha])])
            z_i = dist.Categorical(probs)().item()
            table_assignment.append(z_i)
        num_customers.append(0)
        num_customers[z_i] += 1
        
    return torch.tensor(table_assignment)


# chinese_restaurant_process(num_customers = 100, alpha = 3)
data = create_dataset_crp(alpha = 3, n=10)

def model(data):
    n = int(data.sum().item())
    num_customers = [] # number of customers at each table
    loc,scale = torch.zeros(1),torch.ones(1)*2
    alpha=pyro.sample("alpha",dist.LogNormal(loc,scale)) # alpha must be more than zero 
    for i in range(n):
        if i == 0:
            z_i = 0 # first customer always sits at table 0
        else:
            probs = torch.tensor([c/(i+alpha) for c in (num_customers + [alpha])])
            z_i = dist.Categorical(probs)().item()
        num_customers.append(0)
        num_customers[z_i] += 1
        
    with pyro.iarange("data",len(data)):
        pyro.sample("obs", dist.Categorical(probs).expand_by(data.shape), obs = data)



def guide(data):
    q_loc = pyro.param("q_loc",torch.zeros(1))
    q_scale = pyro.param("q_scale", torch.ones(1)*2)
    pyro.sample("alpha", dist.LogNormal(q_loc,q_scale))

optim = pyro.optim.Adam({'lr': 0.001, 'betas': [0.9, 0.99]})
svi = SVI(model, guide, optim, loss=Trace_ELBO())

def main(data,num_iter=2000):
    pyro.clear_param_store()
    start =time.time()
    losses = np.zeros(num_iter)
    for i in range(num_iter):
        losses[i] = svi.step(data)
        if i % (num_iter//10) == 0:
            print("[iteration %04d] loss: %.4f" % (i + 1, losses[i]))
            elapsed_time = time.time()
            print("elapsed time: %.2f" %(elapsed_time-start))
    end = time.time()
    print("Loop take time %.2f"%(end-start))
    plt.plot(losses)
    plt.show()

main(data,2000)
for name in pyro.get_param_store().get_all_param_names():
    print("[%s]: %.3f" % (name, pyro.param(name).data.numpy()))

The parameter that I want to infer is the concentration parameter alpha. My problem is that the loss does not converge at all. Is there any advice on this?

There was some discussion about the Dirichlet process on this forum some time ago, but my problem is a little bit different, so I opened this topic.

I revised the code from my previous post so that it now looks like this:

def create_dataset_crp(alpha=3,n_customers=10):
    customers_in_table = [1] # number of customers at each table
    table_assignment =[0]
    n = n_customers-1
    probs = torch.tensor([1/1+alpha,alpha/1+alpha])
    for i in range(n):
        z_i = dist.Categorical(probs)().item()
        total_customer = i + 1
        if z_i in table_assignment:
            customers_in_table[z_i] += 1
        else:
            customers_in_table += [ 1]
        probs = [n/(total_customer + alpha) for n in customers_in_table+[alpha]]
        probs = torch.tensor(probs)
        table_assignment.append(z_i)
    return torch.tensor(table_assignment)

def model(data):
    # initial condition
    customers_in_table = [1] # number of customers at each table
    table_assignment =[0] # assignment of customers to no.table
    n = data.size()[0]-1 
    loc,scale = torch.zeros(1),torch.ones(1)*2
    alpha=pyro.sample("alpha",dist.LogNormal(loc,scale)) # alpha must be more than zero 
    probs = torch.tensor([1/1+alpha,alpha/1+alpha])
    for i in range(n):
        z_i = pyro.sample("sample_{}".format(i), dist.Categorical(probs),obs = data[i])
        total_customer = i + 1
        if z_i in table_assignment:
            customers_in_table[z_i] += 1
        else:
            customers_in_table += [ 1]
        probs = [n/(total_customer + alpha) for n in customers_in_table+[alpha]]
        probs = torch.tensor(probs)
        table_assignment.append(z_i)

def guide(data):
    q_loc = pyro.param("q_loc",torch.zeros(1))
    q_scale = pyro.param("q_scale", torch.ones(1)*2)
    pyro.sample("alpha", dist.LogNormal(q_loc,q_scale))
        


optim = pyro.optim.Adam({'lr': 0.001, 'betas': [0.9, 0.99]})
svi = SVI(model, guide, optim, loss=Trace_ELBO())

def main(data,num_iter=2000):
    pyro.clear_param_store()
    start =time.time()
    losses = np.zeros(num_iter)
    for i in range(num_iter):
        losses[i] = svi.step(data)
        if i % (num_iter//10) == 0:
            print("[iteration %04d] loss: %.4f" % (i + 1, losses[i]))
            elapsed_time = time.time()
            print("elapsed time: %.2f" %(elapsed_time-start))
    end = time.time()
    print("Loop take time %.2f"%(end-start))
    plt.plot(losses)
    plt.show()

main(data,2000)

Hi @yusri_dh, let me see if I understand your problem. Are you assuming that you know the table at which each customer sits, as the contents of data? If so then this is a very easy problem (much easier than the Dirichlet process mixture model problem, where table ids are not observed).

If you’ve observed table ids then there is really no remaining randomness, so you can use an empty guide and omit momentum in your optimizer. In fact your model will look almost exactly like your data generator:

def model(data):
    alpha = pyro.param("alpha", torch.tensor(1.0),
                       constraint=constraints.positive)
    customers_in_table = [] # number of customers at each table
    for i, z_i in enumerate(data):
        probs = torch.cat([torch.tensor(customers_in_table), alpha.unsqueeze(0)])
        pyro.sample("z_{}".format(i), dist.Categorical(probs), obs=z_i)
        if z_i < len(customers_in_table):
            customers_in_table[z_i] += 1.
        else:
            customers_in_table += [1.]

def guide(data):
    pass

optim = Adam({'lr': 0.01, 'betas': (0.5, 0.9)})  # very little momentum betas[0]
svi = SVI(model, guide, optim, Trace_ELBO())

Note that I’ve used the fact that dist.Categorical does not require the probs to be normalized; this helps to avoid simple math errors (like 1/1+alpha vs 1/(1+alpha)).
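
For example, a quick sanity check (not part of the thread's code) showing that Categorical normalizes the probs for you:

import torch
import pyro.distributions as dist

unnormalized = dist.Categorical(torch.tensor([4., 2., 1., 1.]))        # raw counts + alpha
normalized = dist.Categorical(torch.tensor([0.5, 0.25, 0.125, 0.125]))
print(unnormalized.probs)   # tensor([0.5000, 0.2500, 0.1250, 0.1250])
print(normalized.probs)     # same distribution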

Thank you so much @fritzo. It is working well now. There is something that I want to confirm: if I just use a model with an empty guide, does that mean it is trying to find the MAP estimate of the parameter alpha?

yes

Thank you @jpchen.

Here is a complete working example of my code. I just changed fritzo's code a little bit:

def create_dataset_crp(alpha=3,n_customers=10):
    customers_in_table = [1] # number of customers at each table
    table_assignment =[0]
    n = n_customers-1
    probs = torch.tensor([1/(1+alpha), alpha/(1+alpha)])
    for i in range(n):
        z_i = dist.Categorical(probs)().item()
        total_customer = i + 1
        if z_i in table_assignment:
            customers_in_table[z_i] += 1
        else:
            customers_in_table += [ 1]
        probs = [n/(total_customer + alpha) for n in customers_in_table+[alpha]]
        probs = torch.tensor(probs)
        table_assignment.append(z_i)
#         print ("customers_in_table: ",customers_in_table)
#         print("table_assignment: ",table_assignment)
    return torch.tensor(table_assignment)

data = create_dataset_crp(alpha=33,n_customers=100)
print(data)

def model(data):
    alpha = pyro.param("alpha", torch.tensor(1.0),
                       constraint=constraints.positive)
    customers_in_table = [] # number of customers at each table
    table_assignment =[]
    i=0
    for z_i in data:
        probs = torch.cat((torch.tensor(customers_in_table), alpha.view(-1)))
#         print("probs",probs)
        pyro.sample("z_{}".format(i), dist.Categorical(probs), obs=z_i)
        i += 1
        if z_i in table_assignment:
            customers_in_table[z_i] += 1.
        else:
            customers_in_table += [1.]
        table_assignment.append(z_i)
#         print(table_assignment)

def guide(data):
    pass

optim = pyro.optim.Adam({'lr': 0.01, 'betas': [0.5, 0.9]})
svi = SVI(model, guide, optim, loss=Trace_ELBO())

def main(data,num_iter=2000):
    pyro.clear_param_store()
    start =time.time()
    losses = np.zeros(num_iter)
    for i in range(num_iter):
        losses[i] = svi.step(data)
        if i % (num_iter//10) == 0:
            print("[iteration %04d] loss: %.4f" % (i + 1, losses[i]))
            elapsed_time = time.time()
            print("elapsed time: %.2f" %(elapsed_time-start))
    end = time.time()
    print("Loop take time %.2f"%(end-start))
    plt.plot(losses)
    plt.show()

main(data,400)
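
After training, the fitted concentration parameter can be read back from the param store, for example:

print("fitted alpha: %.3f" % pyro.param("alpha").item())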