Variational Inference for Dirichlet process clustering

Hi there! This is my first time using Pyro, so I am very excited to see what I can build with it. :slight_smile:
Specifically, I am trying to do finite Dirichlet Process clustering with Variational Inference. I want to generalize this into a Chinese Restaurant Process involving an “infinite” number of states. But for now, I am just generating 1-D data from 3 Gaussians with proportions given by a Categorical distribution with a Dirichlet prior, and we observe each point with a likelihood given by yet another Gaussian.
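
To make the setup concrete, here is a minimal plain-Python sketch of the generative process just described (no Pyro involved). The function name, the default hyperparameters, and the noise scale are illustrative assumptions, not the values from my actual experiment; the Dirichlet draw uses the standard normalized-Gamma construction.

```python
import random

def sample_finite_mixture(n_points, alpha=(1.0, 1.0, 1.0), prior_mu=0.0,
                          prior_sigma=1.0, obs_sigma=1.0, seed=0):
    """Draw (pi, mus, z, x) from a finite Dirichlet-Categorical Gaussian mixture."""
    rng = random.Random(seed)
    # pi ~ Dirichlet(alpha), built from normalized Gamma(alpha_k, 1) draws
    gammas = [rng.gammavariate(a, 1.0) for a in alpha]
    total = sum(gammas)
    pi = [g / total for g in gammas]
    # mu_k ~ Normal(prior_mu, prior_sigma), one mean per cluster
    mus = [rng.gauss(prior_mu, prior_sigma) for _ in alpha]
    z, x = [], []
    for _ in range(n_points):
        k = rng.choices(range(len(alpha)), weights=pi)[0]  # z_n ~ Categorical(pi)
        z.append(k)
        x.append(rng.gauss(mus[k], obs_sigma))  # x_n ~ Normal(mu_{z_n}, obs_sigma)
    return pi, mus, z, x

pi, mus, z, x = sample_finite_mixture(1000)
```

Here z plays the role of the labels y and x the role of the observed data.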

Basically, the directed graph for the generative process is as follows:

where miu_k indicates the kth mean.

The joint distribution of the same graph is given by this:
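
Written out, and assuming the standard factorization for this kind of mixture (the symbols $\pi$ for the proportions, $\alpha_0$ for the Dirichlet parameter, $K$ for the number of clusters, and $N$ for the number of points are notation chosen here), it has the form:

$$p(\pi, \mu, z, x) = p(\pi \mid \alpha_0) \prod_{k=1}^{K} p(\mu_k) \prod_{n=1}^{N} p(z_n \mid \pi)\, p(x_n \mid z_n, \mu)$$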

Since my joint distribution depends on a global variable (the prior), a cluster-dependent Gaussian, and a likelihood conditioned on the $z_n$ and miu, I am not sure what to do in the model() function. The SVI Part II tutorial (SVI Part II: Conditional Independence, Subsampling, and Amortization, Pyro Tutorials 1.8.4 documentation) has the following code for a similar situation:

def model(data):
    beta = pyro.sample("beta", ...) # sample the global RV
    for i in pyro.irange("locals", len(data)):
        z_i = pyro.sample("z_i", ...)
        # compute the parameter used to define the observation
        # likelihood using the local random variable
        theta_i = compute_something(z_i)
        pyro.observe("obs_{}".format(i), dist.mydist,
                     data[i], theta_i)

But I am still confused as to what I should do here. Should I have a nested for-loop inside the main one in which I sample the miu_k independently? To “observe” each data point given miu_k and z_i, how do I use the normal distribution to give a likelihood?

Here is what I have:

def model(data):
    alpha0 = Variable(torch.Tensor([1.0,1.0,1.0]))
    prior_mu = Variable(torch.zeros([1, 1]))
    prior_sigma = Variable(torch.ones([1, 1]))    

    for i in range(len(data)):
        zn = pyro.observe("latent_proportions", dist.categorical,pi,to_one_hot(y[i].data,3))
        mu = pyro.sample("latent_locations", dist.normal,prior_mu,prior_sigma )
        # observe data given cluster
        for k in range(k):
             pyro.sample("latent_locations", dist.normal,prior_mu,prior_sigma )
               
        pyro.observe("obs_{}".format(i), dist.normal,data[i], mu,Variable(torch.ones([1, 1])))

I don’t really know where to go from here, since the last line does not make use of latent_locations (mu) in my formulation. What should I do in this case, when there are two latent variables, each with a different range in the summation?

Hi Vincent, that’s an interesting model. One thing to keep in mind is that all pyro.sample() sites need distinct names, whereas in your for loop the same name is used for every pass through the loop.

I think I would try to lazily sample the cluster parameters, something like this (still needs work!):

def model(data):
    alpha0 = 1.0
    prior_mu = Variable(torch.zeros([1, 1]))
    prior_sigma = Variable(torch.ones([1, 1]))
    sigma = Variable(torch.ones([1, 1]))

    cluster_means = {}  # sample this lazily
    crp_counts = []  # build this incrementally
    for i in range(len(data)):
        # sample from a CRP
        crp_weights = Variable(torch.Tensor(crp_counts + [alpha0]))
        crp_weights /= crp_weights.sum()
        zi = pyro.sample("z_{}".format(i), dist.categorical, crp_weights)
        zi = zi.data[0]  # this should be an int, not sure I've done it right
        if zi >= len(crp_counts):  # index len(crp_counts) is the new table
            crp_counts.append(1)  # sit at a new table
        else:
            crp_counts[zi] += 1  # sit at an existing table

        # lazily sample cluster mean
        if zi not in cluster_means:
            cluster_means[zi] = pyro.sample("mu_{}".format(zi), dist.normal, prior_mu, prior_sigma)
        mui = cluster_means[zi]
        pyro.observe("obs_{}".format(i), dist.normal, data[i], mui, sigma)
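
To sanity-check the seating rule separately from Pyro, here is a small plain-Python simulation of the same Chinese restaurant process. It is only a sketch: the function name and the use of the standard-library random module are choices made for this illustration, and it leaves out the cluster means and observations entirely.

```python
import random

def simulate_crp(n_customers, alpha=1.0, seed=0):
    """Return (assignments, table_counts) for a Chinese restaurant process."""
    rng = random.Random(seed)
    counts = []       # customers per existing table
    assignments = []
    for _ in range(n_customers):
        # existing tables are weighted by their counts, a new table by alpha
        weights = counts + [alpha]
        table = rng.choices(range(len(weights)), weights=weights)[0]
        if table == len(counts):
            counts.append(1)      # index len(counts) means "open a new table"
        else:
            counts[table] += 1    # join an existing table
        assignments.append(table)
    return assignments, counts

assignments, counts = simulate_crp(1000, alpha=1.0)
```

With alpha = 1.0 the number of occupied tables grows slowly with the number of customers, which is the behavior the model relies on.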

If you get this working, we should add a Dirichlet Process Mixture Model tutorial to Pyro :smile:

Hi Fritz,

Thank you so much for the helpful message! :+1: I have been very busy with another project, so sorry for the late reply. :slight_smile:

I have some follow up questions:

During training, the labels of the clusters are observed, but for testing they aren’t. So in my guide and model functions, should I include an ‘if’ statement to handle this difference, like in the Semi-Supervised VAE tutorial?

Secondly, since I am using the mean-field approximation, i.e. my q depends on both global and local variables, should there be a for loop in the guide like the one in the model? The tutorial on conditional independence does this, but others don’t. Since my model is adding clusters “on the fly”, how should we represent this in the guide?

My guide is expressed in the following form.

I agree that we should add a tutorial on Dirichlet process clustering since Pyro will be very suitable for this!

I will be working on this until I figure this out. I will post the part of my code I am confused about along with the suggested changes you mentioned later.

Vincent

Hi Fritz,

So I have been working on my guide function, and I have questions that I want to ask you.

First, to get the inferred locations of the cluster means and the assignment of each data point, should there be a pyro.param("assignment_{}".format(i)) in the for loop of the guide (meaning we declare a new parameter for each data point)? I don’t have a parameter for the cluster assignment right now. That way, at the end of training, I can print out the inferred variational parameters to know the locations.

Secondly, I am declaring a pyro.param("weights_{}".format(i), crp_weights_q) for each point so that I can pyro.sample it:

  crp_weights_q = pyro.param("weights_{}".format(i), crp_weights_q)
  zi_q = pyro.sample("z_{}".format(i), dist.categorical, crp_weights_q)

Is this the right way to do it if my z_{}.format is declared the way you suggested in your previous post? Since crp_weights_q is changing all the time, is this the correct way of telling Pyro that there is a parameter for each data point?

Thirdly, I am running a basic inference function like this:


pyro.clear_param_store()
# setup the optimizer
adam_params = {"lr": 0.0005, "betas": (0.90, 0.999)}
optimizer = Adam(adam_params)

# setup the inference algorithm
svi = SVI(model, guide, optimizer, loss="ELBO", num_particles=7)

n_steps = 4000
# do gradient steps
# x is the data and y is the label
for step in range(n_steps):
    svi.step(x,y)
    if step % 100 == 0:
        print('.', end='')

My data has 1000 points and labels, and I get the following error:

site z_0 is observed and should not be overwritten

This happens after the 1000 points are processed in the guide function. It seems like Pyro wants to process the data again, but since the sample function already has the same name for this sample, it complains about it. However, I don’t see any examples that clarify this issue. Pyro should process the data, which has 1000 points, for n_steps = 4000, right?

Surprisingly, if I do not print the z_i at each iteration (of the data) in the guide function, svi returns this error immediately, i.e. in the first step. But if I do print z_i, I only get this error at the end of my 1000 points, and I cannot figure out what causes this strange behavior. It seems a print function in the guide function can throw things off.

Here is my code for the guide. It is based on the model function you provided, with some modification of the categorical distribution (I had to change it since this returns a one-hot encoded vector instead of an integer that I could use to index into cluster_means_q; I did the same in my model function as well). Please let me know if I did this correctly. Thank you!

def guide(x,y):
    # locations of the clusters
    mu_q_0 = Variable(torch.zeros([1, 1]),
                             requires_grad=True)
    sigma_q_0 = Variable(torch.ones([1, 1]),
                             requires_grad=True)

    # register parameters
    mu_q = pyro.param("mu", mu_q_0)
    sigma_q = pyro.param("sigma", sigma_q_0)

    # prior on alpha, will register this later
    alpha_q = (4)

    cluster_means_q = {}  # sample this lazily
    crp_counts_q = []  # build this incrementally


    for i in range(len(x)):

        crp_weights_q = Variable(torch.Tensor(crp_counts_q + [alpha_q]),requires_grad=True)

        crp_weights_q /= crp_weights_q.sum()
        crp_weights_q = pyro.param("weights_{}".format(i), crp_weights_q)
        zi_q = pyro.sample("z_{}".format(i), dist.categorical, crp_weights_q)

        zi_q = torch.nonzero(zi_q.data)
        zi_q = zi_q[0][0]

        if zi_q >= len(crp_counts_q):
            crp_counts_q.append(1)  # sit at a new table
        else:
            crp_counts_q[zi_q] += 1  # sit at an existing table
        if zi_q not in cluster_means_q:
            cluster_means_q[zi_q] = pyro.sample("mu_{}".format(zi_q), dist.normal, mu_q, sigma_q)
        # print('zi',zi_q) without this the program returns an error at I=1

To address some of your questions:

should there be a pyro.param(assignment_{}.format(i)) in the for loop of the guide(meaning we declare a new parameter for each data point

Yes, which will learn individual assignments jointly with your weights.

is this the correct way of telling pyro that there is parameter for each data point

Yes, if I understand correctly, you want a sample per weight, so what you have above has the correct semantics.

site z_0 is observed and should not be overwritten

This means that your sample and your observes have the same names, so on the next training iteration you’re reusing the same name, since the name is indexed by your inner for loop. Are your observes named z_i?

Hi,

For the third error (quoting jpchen):

site z_0 is observed and should not be overwritten

For my model I have

zi = pyro.sample("z_{}".format(i), dist.categorical, crp_weights,obs=to_one_hot(y[i]),one_hot=False)

The obs=to_one_hot(y[i]) is the label that I have for training data. For the guide, there are no obs or pyro.observe statements.

For my guide I have

zi_q = pyro.sample("z_{}".format(i), dist.categorical, crp_weights_q)

which does not have any obs statements.

From what I understand, samples like the z_i should be shared between the model and the guide function, so what should I do if I have observations in the model function where I know the label for the training data?

Yeah, there’s your problem. Your observe statement should have no corresponding ‘sample’ in the guide, since it is used in the likelihood term in SVI. However, all the other sample statements in your model should have corresponding samples in your guide. Try renaming your observe to, e.g., what fritz suggested above:
zi = pyro.sample("obs_{}".format(i), ... ,obs=...)
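
The naming rule can be illustrated with a toy consistency check in plain Python. This is not Pyro's actual implementation, just a sketch with a hypothetical helper (check_sites) that compares site names the way the rule above describes: latent model sites need a guide counterpart, observed sites must not have one.

```python
def check_sites(model_latent, model_observed, guide_sites):
    """Return (missing, extra, clashes) as sorted lists of site names.

    missing: latent model sites with no matching guide site
    extra:   guide sites that match no latent model site
    clashes: guide sites whose name is already an observed model site
    """
    latent, observed, guide = set(model_latent), set(model_observed), set(guide_sites)
    missing = sorted(latent - guide)
    extra = sorted(guide - latent - observed)
    clashes = sorted(guide & observed)
    return missing, extra, clashes

n = 3
model_latent = ["mu_{}".format(i) for i in range(n)]
# labels observed under the name z_i, data observed under the name obs_i
model_observed = ["z_{}".format(i) for i in range(n)] + ["obs_{}".format(i) for i in range(n)]
# a guide that samples both mu_i and z_i
guide_sites = model_latent + ["z_{}".format(i) for i in range(n)]

missing, extra, clashes = check_sites(model_latent, model_observed, guide_sites)
```

In this toy setup the z_i names end up in clashes, which is the situation behind the "observed and should not be overwritten" message.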

Hi,
I am still a little confused. I have to sample zi and zi_q from a categorical distribution in both my model and guide. These zi/zi_q are the cluster assignments for each point(i.e. they are the “y” labels). It is just that for the case where I know the labels, I can add in the obs=.... statement in the pyro.sample statement in the model function.

Now, since it is not allowed to pyro.observe something and pyro.sample the same name in both the guide and the model, how would I incorporate the knowledge I have for the label in the model, and sample the same thing in my guide?

Regarding your previous comment, my model function does have an

pyro.observe("obs_{}".format(i), dist.normal, data[i], mui, sigma)

but this is for the value of the data points(i.e. the X not the label ), which is not related to the zi. If I change the code to zi = pyro.sample("obs_{}".format(i), ... ,obs=...) in my model function, then I can not sample zi in my guide anymore, since it complains that

Found vars in guide but not model: {}".format(guide_vars - model_vars

I have also been getting the following error, and I do not know what causes it. This happens when I sample from the Categorical distribution with one_hot=False.

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
/Users/zhitingchen/anaconda/lib/python3.6/site-packages/pyro/poutine/trace.py in log_pdf(self, site_filter)
     66                 try:
---> 67                     site_log_p = site["log_pdf"]
     68                 except KeyError:

KeyError: 'log_pdf'

During handling of the above exception, another exception occurred:

RuntimeError                              Traceback (most recent call last)
<ipython-input-30-03a21f270021> in <module>()
     10 # do gradient steps
     11 for step in range(n_steps):
---> 12     svi.step(x,y)
     13     if step % 100 == 0:
     14         print('.', end='')

/Users/zhitingchen/anaconda/lib/python3.6/site-packages/pyro/infer/svi.py in step(self, *args, **kwargs)
     96         """
     97         # get loss and compute gradients
---> 98         loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
     99 
    100         # get active params

/Users/zhitingchen/anaconda/lib/python3.6/site-packages/pyro/infer/elbo.py in loss_and_grads(self, model, guide, *args, **kwargs)
     63         :rtype: float
     64         """
---> 65         return self.which_elbo.loss_and_grads(model, guide, *args, **kwargs)

/Users/zhitingchen/anaconda/lib/python3.6/site-packages/pyro/infer/trace_elbo.py in loss_and_grads(self, model, guide, *args, **kwargs)
    131         elbo = 0.0
    132         # grab a trace from the generator
--> 133         for weight, model_trace, guide_trace, log_r in self._get_traces(model, guide, *args, **kwargs):
    134             elbo_particle = weight * 0
    135             surrogate_elbo_particle = weight * 0

/Users/zhitingchen/anaconda/lib/python3.6/site-packages/pyro/infer/trace_elbo.py in _get_traces(self, model, guide, *args, **kwargs)
     85             model_trace = prune_subsample_sites(model_trace)
     86 
---> 87             log_r = model_trace.log_pdf() - guide_trace.log_pdf()
     88             weight = 1.0 / self.num_particles
     89             yield weight, model_trace, guide_trace, log_r

/Users/zhitingchen/anaconda/lib/python3.6/site-packages/pyro/poutine/trace.py in log_pdf(self, site_filter)
     69                     args, kwargs = site["args"], site["kwargs"]
     70                     site_log_p = site["fn"].log_pdf(
---> 71                         site["value"], *args, **kwargs) * site["scale"]
     72                     site["log_pdf"] = site_log_p
     73                 log_p += site_log_p

/Users/zhitingchen/anaconda/lib/python3.6/site-packages/pyro/distributions/random_primitive.py in log_pdf(self, x, *args, **kwargs)
     40 
     41     def log_pdf(self, x, *args, **kwargs):
---> 42         return self.dist_class(*args, **kwargs).log_pdf(x)
     43 
     44     def batch_log_pdf(self, x, *args, **kwargs):

/Users/zhitingchen/anaconda/lib/python3.6/site-packages/pyro/distributions/distribution.py in log_pdf(self, x, *args, **kwargs)
    183         :rtype: torch.autograd.Variable
    184         """
--> 185         return torch.sum(self.batch_log_pdf(x, *args, **kwargs))
    186 
    187     @abstractmethod

/Users/zhitingchen/anaconda/lib/python3.6/site-packages/pyro/distributions/categorical.py in batch_log_pdf(self, x)
    160                 boolean_mask = x
    161             else:
--> 162                 boolean_mask = torch_zeros_like(logits.data).scatter_(-1, x.data.long(), 1)
    163         boolean_mask = boolean_mask.cuda() if logits.is_cuda else boolean_mask.cpu()
    164         if not isinstance(boolean_mask, Variable):

RuntimeError: Invalid index in scatter at /Users/soumith/miniconda2/conda-bld/pytorch_1502000975045/work/torch/lib/TH/generic/THTensorMath.c:515

Could this be related to this?

When I have one_hot=True, since the sample from the model and guide functions might return one-hot vectors of different sizes, it complains about size mismatches like the following:

Model and guide dims disagree at site 'z_9': torch.Size([8]) vs torch.Size([9])

This happens because the number of clusters from the categorical distribution is always changing in the Dirichlet process. Do you have any ideas on how to fix this?

@vincent Here are some of my suggestions. Hope they help.

  • Forget about guide. Parameterize your cluster mean distribution by some parameters (similar to what you did for the guide: using pyro.param).

  • After you sample cluster means, sample a distribution P over these means using Dirichlet distribution.

  • Now, go to the for loop: using P as probs for a Categorical distribution, get a sample from it. This sample will give you the corresponding cluster mean from step 2. Then sample your data point from this cluster mean (using Normal distribution). Because this step is independent, you can use irange.

  • Try to run inference on the above model using SVI (with an empty guide, which means you are doing maximum likelihood). I would suggest making this work first before defining any guide for your model. It is easy to get things wrong when building the model as a first step.

  • For Chinese restaurant process, you can consult @fritzo’s suggestion.

  • It would be easier to diagnose your problem if you used the latest version of Pyro. :slight_smile: (We now distinguish the Categorical distribution from the OneHotCategorical distribution, and have removed the random primitives such as dist.normal,… in favor of dist.Normal(mu, sigma).)
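
As a rough illustration of a maximum likelihood objective for the finite mixture in the steps above, here is a plain-Python sketch. Note that it sums out the assignment of each point analytically instead of sampling it, which is a different (swapped-in) treatment of z than the sampled version in the list. The function names and the example numbers are made up, and the observation noise is fixed to 1 as an assumption.

```python
import math

def normal_logpdf(x, mu, sigma):
    return -0.5 * math.log(2 * math.pi * sigma ** 2) - (x - mu) ** 2 / (2 * sigma ** 2)

def mixture_log_likelihood(data, probs, mus, sigma=1.0):
    """log p(data | probs, mus) with each assignment z_n summed out."""
    total = 0.0
    for x in data:
        # log sum_k probs[k] * Normal(x | mus[k], sigma), via log-sum-exp
        terms = [math.log(p) + normal_logpdf(x, m, sigma) for p, m in zip(probs, mus)]
        top = max(terms)
        total += top + math.log(sum(math.exp(t - top) for t in terms))
    return total

data = [-5.1, -4.9, 0.2, 4.8, 5.3]
good = mixture_log_likelihood(data, [0.4, 0.2, 0.4], [-5.0, 0.0, 5.0])
bad = mixture_log_likelihood(data, [0.4, 0.2, 0.4], [0.0, 0.0, 0.0])
```

Cluster means that sit near the data (good) score higher than means collapsed at zero (bad), which is the signal a maximum likelihood fit would follow.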

Thanks for the suggestions.

I think my biggest problem now is that the zi(samples from my categorical distribution) have different sizes in my guide and my model. This happens because a Chinese restaurant process has no predetermined number of clusters, i.e. the P for the categorical is always changing as more data points are observed. This causes a mismatch between the model and the guide because I use the same names:

pyro.sample("z_{}".format(i)....)

for both the model and the guide.

Regarding your suggestion of using maximum likelihood, are you referring to this: pyro/gmm.ipynb at 1bbfb48aa280dab36ea8b1ff17826665e1fd9eb6 · pyro-ppl/pyro · GitHub?

I didn’t know you could have an empty guide function. I have a pass statement in my guide, and that results in a warning about vars found in model but not in guide, because the samples in the model are not shared with the guide.

Thanks

Hi @vincent, sorry if I caused any confusion.

By empty guide, I meant using MAP inference (>"<).

If you have a sample statement like this:
a = pyro.sample("a", dist.Normal(ng_zeros(1), ng_ones(1))),
then in guide, change it to: a_map = pyro.param("a_map", Variable(torch.zeros(1), requires_grad=True)) and a = pyro.sample("a", dist.Delta(a_map)). Do this for all sample statements except the ones with obs=... parameters. (This way, we don’t have to think about which guide is suitable for our model).
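
To see what the Delta guide buys you: with q concentrated on a single point a_map, the ELBO reduces to the log joint evaluated at that point, so the optimizer is doing gradient ascent on log p(a) + log p(x | a). Below is a minimal plain-Python sketch of that idea, assuming a likelihood x_n ~ Normal(a, 1) (the likelihood, the data, and the function names are assumptions for illustration). For this conjugate case the MAP value is known in closed form, sum(x) / (n + 1), so the result can be checked.

```python
def log_joint_grad(a, data):
    # d/da [log Normal(a | 0, 1) + sum_n log Normal(x_n | a, 1)]
    return -a + sum(x - a for x in data)

def map_estimate(data, lr=0.05, n_steps=500):
    a_map = 0.0  # plays the role of the "a_map" parameter, initialized at zero
    for _ in range(n_steps):
        a_map += lr * log_joint_grad(a_map, data)  # plain gradient ascent
    return a_map

data = [1.8, 2.1, 2.4, 1.7]
a_hat = map_estimate(data)
closed_form = sum(data) / (len(data) + 1)
```

The step size has to stay below 2 / (n + 1) for this simple update to converge; an adaptive optimizer such as Adam is less sensitive to that choice.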

For Chinese restaurant process mixture, I think that we can construct a model as follows:

def model(x):
    n = x.size(0)
    alpha = 1.0  # CRP concentration; assumed value, not specified in the original post
    mu_list = []
    num_customers = []  # number of customers at each table
    
    for i in range(n):
        mu_i = pyro.sample("mu_{}".format(i), dist.Normal(ng_zeros(1), ng_ones(1)))
        mu_list.append(mu_i)

        if i == 0:
            z_i = 0  # first customer always sits at table 0
        else:
            probs = Variable(torch.Tensor([c/(i+alpha) for c in (num_customers + [alpha])]))
            z_i = pyro.sample("z_{}".format(i), dist.Categorical(probs)).data[0]  # sample which table the new customer will sit

        num_customers.append(0)
        num_customers[z_i] += 1  # add 1 to that table
        pyro.sample("x_{}".format(i), dist.Normal(mu_list[z_i], ng_ones(1)), obs=x[i])

Note that both model and guide in SVI use the same x at each step: svi.step(x).

Hi @fehiepsi @fritzo ,

I currently have something similar to your formulation above. However, in order to plot the inferred means at the end of the day, I need to store them in a pyro.param(‘means’,…) statement. Since the number of means keeps growing, how can I do this?

I am currently initializing the means vector with a fixed number of means (more than I need, so we won’t overflow), and only index into this vector with z_i as you have. I plan to plot only the top three means as ranked by the proportions stored in the probs vector, which holds the mixture proportions. Ideally, the real clusters will have high weights, and the rest will go towards zero.
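
The ranking step I have in mind looks roughly like this in plain Python. It is only a sketch: the helper name top_clusters is hypothetical, and the weights and means below are made-up numbers standing in for the inferred values.

```python
def top_clusters(weights, means, k=3):
    """Return the k (weight, mean) pairs with the largest mixture weights."""
    ranked = sorted(zip(weights, means), key=lambda pair: pair[0], reverse=True)
    return ranked[:k]

# made-up inferred values for a truncation with 6 candidate clusters
weights = [0.31, 0.01, 0.35, 0.02, 0.30, 0.01]
means = [-4.9, 0.7, 0.1, 2.2, 5.2, -1.3]
kept = top_clusters(weights, means)
```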

Your formulation uses a list to store the means, so how do you output the result of your inference if you don’t register it into a param?

Thank you

@vincent if you use MAP, then I guess pyro.param("mu_{}_map".format(i)) is what you are looking for.

Hi @fehiepsi,

Thanks for your suggestions

Before, we had to sample zi from a categorical like this:

        z = pyro.sample("z_{}".format(i), dist.categorical, (crp_weights_q))

Then we used zi to index into another list or Tensor containing the means like this:

        mu_z = mu.index_select(0, z)
        pyro.sample('x', Normal(mu_z, sigma.expand_as(mu_z)), obs=data)

Your MAP technique is a convenient way of decoupling the sampling process in the guide from any structure there may be in the model. This is why you said this technique works for any guide right?

    mu_map = pyro.param("mu_map_{}".format(z), Variable(torch.zeros(1), requires_grad=True))  
    mu = pyro.sample("mu_{}".format(z), dist.Delta(mu_map))

However, since the MAP technique does not use samples from the categorical to index into anything, this means we won’t have any zi in the guide. Can you recommend a MAP technique for the zi as well?

my attempt was the following:

         zi_map = softplus(pyro.param("z_map_{}".format(i), Variable(torch.Tensor([1]), requires_grad=True)))  
         zi_q = pyro.sample("z_{}".format(i), dist.categorical, (zi_map))
         pyro.param("z_{}_map".format(i),zi_map)

I have two ways of solving this:

Option 1 in the code makes sense to me because I need the crp_weights to see which cluster assignments have small weights so that I can filter them out. However, this does not run (Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.)

Option 2 runs but puts everything into one cluster, i.e. zi are always 1.

It might be easier to look at the guide I have:

def guide(x,y):
    # prior on alpha, will register this later
    alpha_q = (2)

    cluster_means_q = {}  # sample this lazily
    crp_counts_q = []  # build this incrementally

    for i in pyro.irange("data_loop", len(x)):
        crp_weights_q = softmax(Variable(torch.Tensor(crp_counts_q + [alpha_q]),requires_grad=True))

        #option 1
        zi_map = pyro.param("z_map",crp_weights_q)

        #option 2
        zi_map = softplus(pyro.param("z_map_{}".format(i), Variable(torch.Tensor([1]), requires_grad=True)))

        zi_q = pyro.sample("z_{}".format(i), dist.categorical, (zi_map))
        #convert into an int to index
        z_ind = ((zi_q.data[0]))

        if z_ind >= len(crp_counts_q):
            crp_counts_q.append(1)  # sit at a new table
        else:
            crp_counts_q[z_ind] += 1  # sit at an existing table

        if z_ind not in cluster_means_q: #is this if statement necessary
            mu_map = pyro.param("mu_map_{}".format(i), Variable(torch.zeros(1), requires_grad=True))
            mu = pyro.sample("mu_{}".format(z_ind), dist.Delta(mu_map))
            cluster_means_q[z_ind] = mu

Thank you so much for the help. I don’t want to be spamming the forum, so please let me know if I am.

Hi @vincent, the discussion is helpful for other people too, so you don’t have to worry about it. As for MAP, it is just a starting step that saves you from thinking about the guide. For example, the corresponding guide for my last model is:

def MAP_guide(x):
    n = x.size(0)
    alpha = 1.0  # CRP concentration; assumed value, must match the one used in the model
    mu_list = []
    num_customers = []  # number of customers at each table
    
    for i in range(n):
        mu_i_MAP = pyro.param("mu_{}_MAP".format(i), Variable(torch.zeros(1), requires_grad=True))
        mu_i = pyro.sample("mu_{}".format(i), dist.Delta(mu_i_MAP))
        mu_list.append(mu_i)

        if i == 0:
            z_i = 0  # first customer always sits at table 0
        else:
            probs = Variable(torch.Tensor([c/(i+alpha) for c in (num_customers + [alpha])]))
            z_i = pyro.sample("z_{}".format(i), dist.Categorical(probs)).data[0]  # sample which table the new customer will sit

        num_customers.append(0)
        num_customers[z_i] += 1  # add 1 to that table

MAP does not work with discrete distributions like Categorical. If you want to get information about z, you can define a simple guide for it. For example

def guide(x):
    n = x.size(0)
    
    for i in range(n):
        mu_i_MAP = pyro.param("mu_{}_MAP".format(i), Variable(torch.zeros(1), requires_grad=True))
        mu_i = pyro.sample("mu_{}".format(i), dist.Delta(mu_i_MAP))

        if i == 0:
            z_i = 0  # first customer always sits at table 0
        else:
            z_i_probs_unconstrained = pyro.param("z_{}_probs_unconstrained".format(i), Variable(torch.ones(i+1), requires_grad=True))
            z_i_probs = softmax(z_i_probs_unconstrained)
            z_i = pyro.sample("z_{}".format(i), dist.Categorical(z_i_probs)).data[0]

Now, params mu_{}_MAP, z_{}_probs_unconstrained will be learnt to best describe the posterior distribution. You can also define a more complicated guide which takes information from x. Please let me know if it works. :slight_smile:
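
As a small aside on the parameterization: the unconstrained vector of length i+1 is mapped to probabilities with softmax, so the all-ones initialization corresponds to a uniform distribution over the i+1 candidate tables, and after training the largest entry marks the most probable table. Here is a plain-Python sketch of that mapping (the "learned" values are hypothetical, just to show the shape of the result):

```python
import math

def softmax(logits):
    top = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - top) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

i = 4
init_probs = softmax([1.0] * (i + 1))            # mirrors the torch.ones(i+1) initialization
learned = softmax([3.0, 0.0, 0.0, 0.0, -2.0])    # hypothetical values after training
best_table = learned.index(max(learned))
```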

@fehiepsi

I didn’t want to make a new topic since it’s a similar question: what is the way to update a local variable when it depends on all observations except the current one, as in:

P(z | z_{-k})

This is modeled after the collapsed variational inference, where the weights are marginalized out:

http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.79.6178&rep=rep1&type=pdf
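
For reference, under the CRP prior alone (ignoring the likelihood term that a full collapsed update would also need), this conditional follows from exchangeability by treating customer k as the last to arrive. A minimal plain-Python sketch, with a hypothetical function name and made-up assignments:

```python
from collections import Counter

def crp_conditional(assignments, k, alpha=1.0):
    """Prior P(z_k = c | z_{-k}) under a CRP with concentration alpha.

    Returns a dict mapping each existing table label to its probability,
    plus the key "new" for opening a new table.
    """
    others = assignments[:k] + assignments[k + 1:]   # drop customer k
    counts = Counter(others)
    denom = len(others) + alpha
    probs = {table: n / denom for table, n in counts.items()}
    probs["new"] = alpha / denom
    return probs

z = [0, 0, 1, 0, 2, 1]
cond = crp_conditional(z, k=4, alpha=1.0)
```

Customer 4 was the only one at table 2, so once it is removed that table disappears and its mass moves to the "new" entry.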

@rmehta1987 I am not familiar with collapsed variational inference. I think it is better to open a new topic and state the problem more concretely (with some code, etc.), so other people can help you.