Variational Inference for Dirichlet process clustering

Hi @vincent, sorry if I caused any confusion.

By "empty guide", I meant using MAP inference (>"<).

If you have a sample statement like this:
a = pyro.sample("a", dist.Normal(torch.zeros(1), torch.ones(1)))
Then in the guide, replace it with a_map = pyro.param("a_map", torch.zeros(1)) and a = pyro.sample("a", dist.Delta(a_map)). pyro.param already makes the tensor trainable, so no requires_grad is needed. Do this for every sample statement except the observed ones (those with obs=...). This way, we don't have to think about which guide is suitable for our model.

For a Chinese restaurant process (CRP) mixture, I think we can construct the model as follows:

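import torch
import pyro
import pyro.distributions as dist

def model(x, alpha=1.0):
    # x: 1-D tensor of observations; alpha: CRP concentration parameter
    n = x.size(0)
    mu_list = []
    num_customers = []  # num_customers[k] = customers at the table opened by customer k

    for i in range(n):
        # candidate mean for table i (only used if customer i opens a new table)
        mu_i = pyro.sample("mu_{}".format(i), dist.Normal(0.0, 1.0))
        mu_list.append(mu_i)

        if i == 0:
            z_i = 0  # first customer always sits at table 0
        else:
            # existing table k: n_k / (i + alpha); new table (index i): alpha / (i + alpha)
            probs = torch.tensor([c / (i + alpha) for c in num_customers + [alpha]])
            z_i = pyro.sample("z_{}".format(i), dist.Categorical(probs)).item()  # which table the new customer sits at

        num_customers.append(0)  # slot for the table customer i could have opened
        num_customers[z_i] += 1  # add 1 to the chosen table
        pyro.sample("x_{}".format(i), dist.Normal(mu_list[z_i], 1.0), obs=x[i])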
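For reference, the probs line in the model implements the standard CRP seating rule with concentration alpha. With customers 0-indexed as in the loop, let n_k be the number of customers among 0, ..., i-1 already seated at table k; then customer i chooses

$$
p(z_i = k \mid z_0, \dots, z_{i-1}) =
\begin{cases}
\dfrac{n_k}{i + \alpha}, & \text{if } n_k > 0 \text{ (an occupied table)}, \\[4pt]
\dfrac{\alpha}{i + \alpha}, & \text{if } k = i \text{ (a new table)}.
\end{cases}
$$

Since the n_k sum to i, these probabilities sum to (i + alpha) / (i + alpha) = 1. Because num_customers keeps one slot per customer (unused slots stay at 0 and therefore get probability 0), entry k of probs lines up with mu_list[k], so a customer who opens a new table uses the freshly sampled mu_i.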

Note that in SVI the model and the guide receive the same x at every step: svi.step(x) passes x to both.
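Before running SVI, it can help to sanity-check the seating logic on its own. The sketch below is plain Python with no Pyro; crp_seating is a hypothetical helper that reuses the model's indexing (table k is the one opened by customer k) and draws assignments from the CRP prior only, with no data involved.

```python
import random


def crp_seating(n, alpha, seed=0):
    """Hypothetical helper: draw table assignments for n customers from a CRP(alpha) prior.

    Mirrors the model's indexing: counts has one slot per customer, and
    slot k belongs to the table that customer k would open.
    """
    rng = random.Random(seed)
    counts = []       # counts[k] = customers at the table opened by customer k
    assignments = []  # assignments[i] = table chosen by customer i
    for i in range(n):
        if i == 0:
            z = 0  # first customer always sits at table 0
        else:
            # weight n_k for table k and alpha for a new table at index i;
            # the weights sum to i + alpha, and zero-weight slots are never chosen
            weights = counts + [alpha]
            z = rng.choices(range(i + 1), weights=weights)[0]
        counts.append(0)
        counts[z] += 1
        assignments.append(z)
    return assignments, counts
```

A larger alpha opens new tables more often, while alpha = 0 seats every customer at table 0.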