Hi all, I want to make a model of Chinese restaurant process(CRP) based on this blog which is correlated to Dirichlet process. From that blog, this works as follows:
- Imagine a restaurant where all your friends went to eat yesterday…
- Initially, the restaurant is empty.
- The first person to enter (Alice) sits down at a table (selects a group). She then orders food for the table (i.e., she selects parameters for the group); everyone else who joins the table will then be limited to eating from the food she ordered.
- The second person to enter (Bob) sits down at a table. Which table does he sit at? With probability α/(1+α) he sits down at a new table (i.e., selects a new group) and orders food for the table; with probability 1/(1+α) he sits with Alice and eats from the food she’s already ordered (i.e., he’s in the same group as Alice).
…- The (n+1)-st person sits down at a new table with probability α/(n+α), and at table k with probability nk/(n+α), where nk is the number of people currently sitting at table k.
The rsult of this process is assignment of customer to table:
chinese_restaurant_process(num_customers = 10, alpha = 1)
first table is table 1
1, 2, 3, 4, 3, 3, 2, 1, 4, 3 # table assignments from run 1
1, 1, 1, 1, 1, 1, 2, 2, 1, 3 # table assignments from run 2
1, 2, 2, 1, 3, 3, 2, 1, 3, 4 # table assignments from run 3
Here is my code:
import os
import numpy as np
import scipy.stats
import torch
import seaborn as sns
from torch.distributions import constraints
import matplotlib.pyplot as plt
import time
%matplotlib inline
import pyro
import pyro.distributions as dist
from pyro.optim import Adam
from pyro.infer import SVI, TraceEnum_ELBO, Trace_ELBO
smoke_test = ('CI' in os.environ)
pyro.enable_validation(True)
def create_dataset_crp(alpha=3,n=10):
num_customers = [] # number of customers at each table
table_assignment =[]
for i in range(n):
if i == 0:
z_i = 0 # first customer always sits at table 0
table_assignment.append(z_i)
else:
probs = torch.tensor([c/(i+alpha) for c in (num_customers + [alpha])])
z_i = dist.Categorical(probs)().item()
table_assignment.append(z_i)
num_customers.append(0)
num_customers[z_i] += 1
return torch.tensor(table_assignment)
# chinese_restaurant_process(num_customers = 100, alpha = 3)
data = create_dataset_crp(alpha = 3, n=10)
def model(data):
n = int(data.sum().item())
num_customers = [] # number of customers at each table
loc,scale = torch.zeros(1),torch.ones(1)*2
alpha=pyro.sample("alpha",dist.LogNormal(loc,scale)) # alpha must be more than zero
for i in range(n):
if i == 0:
z_i = 0 # first customer always sits at table 0
else:
probs = torch.tensor([c/(i+alpha) for c in (num_customers + [alpha])])
z_i = dist.Categorical(probs)().item()
num_customers.append(0)
num_customers[z_i] += 1
with pyro.iarange("data",len(data)):
pyro.sample("obs", dist.Categorical(probs).expand_by(data.shape), obs = data)
def guide(data):
q_loc = pyro.param("q_loc",torch.zeros(1))
q_scale = pyro.param("q_scale", torch.ones(1)*2)
pyro.sample("alpha", dist.LogNormal(q_loc,q_scale))
optim = pyro.optim.Adam({'lr': 0.001, 'betas': [0.9, 0.99]})
svi = SVI(model, guide, optim, loss=Trace_ELBO())
def main(data,num_iter=2000):
pyro.clear_param_store()
start =time.time()
losses = np.zeros(num_iter)
for i in range(num_iter):
losses[i] = svi.step(data)
if i % (num_iter//10) == 0:
print("[iteration %04d] loss: %.4f" % (i + 1, losses[i]))
elapsed_time = time.time()
print("elapsed time: %.2f" %(elapsed_time-start))
end = time.time()
print("Loop take time %.2f"%(end-start))
plt.plot(losses)
plt.show()
main(data,2000)
for name in pyro.get_param_store().get_all_param_names():
print("[%s]: %.3f" % (name, pyro.param(name).data.numpy()))
The parameter that I want to find is the concentration parameter alpha. My problem is the loss was not converging at all. Is there any advice on this?
Actually there are some discussion about Dirichlet process in this forum some time ago. But my problem is a little bit different so I opened this topic.