# Variational Inference for Dirichlet process clustering

Hi there! This is my first time using Pyro, so I am very excited to see what I can build with it.
Specifically, I am trying to do finite Dirichlet process clustering with variational inference. I want to generalize this to a Chinese Restaurant Process involving an "infinite" number of states. For now, though, I am just generating 1-D data from 3 Gaussians, with mixing proportions given by a Categorical distribution under a Dirichlet prior, and we observe each point with a likelihood given by yet another Gaussian.
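A plain-Python sketch of the finite generative process just described may help pin down what the model has to express. It assumes a symmetric Dirichlet(1, 1, 1) prior on the proportions, a N(0, 1) prior on each cluster mean and unit observation noise (the values used in the `model` code further down); `generate` and `sample_dirichlet` are hypothetical helper names, and no Pyro is involved.

``````python
import random


def sample_dirichlet(alpha, rng):
    # A Dirichlet draw: independent Gamma(alpha_k, 1) draws, normalized to sum to 1.
    g = [rng.gammavariate(a, 1.0) for a in alpha]
    total = sum(g)
    return [v / total for v in g]


def generate(n, alpha=(1.0, 1.0, 1.0), prior_mu=0.0, prior_sigma=1.0, sigma=1.0, seed=0):
    rng = random.Random(seed)
    pi = sample_dirichlet(alpha, rng)                        # global mixing proportions
    mus = [rng.gauss(prior_mu, prior_sigma) for _ in alpha]  # one mean per cluster
    z = rng.choices(range(len(alpha)), weights=pi, k=n)      # per-point cluster assignment
    x = [rng.gauss(mus[k], sigma) for k in z]                # per-point observation
    return pi, mus, z, x
``````

Calling `generate(1000)` returns the proportions, the three means, the assignments and the data, which is exactly the set of variables the model has to sample or observe.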

Basically, the directed graph for the generative process is as follows:

where mu_k indicates the k-th cluster mean.

The joint distribution of the same graph is given by this:
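Assuming the model described above, with $K = 3$ clusters, a Dirichlet($\alpha_0$) prior on the proportions $\pi$, a $\mathcal{N}(\mu_0, \sigma_0^2)$ prior on each mean and a fixed observation variance $\sigma^2$ (symbol names are assumptions), the joint factorizes as:

$$
p(\pi, \mu, z, x) = \mathrm{Dir}(\pi \mid \alpha_0) \prod_{k=1}^{K} \mathcal{N}(\mu_k \mid \mu_0, \sigma_0^2) \prod_{n=1}^{N} \mathrm{Cat}(z_n \mid \pi)\, \mathcal{N}(x_n \mid \mu_{z_n}, \sigma^2)
$$

The product over $k$ and the product over $n$ have different ranges, which is why the cluster means are naturally sampled in their own loop, separately from the per-point loop over $n$.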

Since my joint distribution involves a global variable (the prior), cluster-dependent Gaussians, and a likelihood conditioned on the $z_n$ and mu, I am not sure what to do in the `model()` function. The SVI Part II tutorial (SVI Part II: Conditional Independence, Subsampling, and Amortization, Pyro Tutorials 1.8.4 documentation) has the following code for a similar situation:

``````python
def model(data):
    beta = pyro.sample("beta", ...)  # sample the global RV
    for i in pyro.irange("locals", len(data)):
        z_i = pyro.sample("z_i", ...)
        # compute the parameter used to define the observation
        # likelihood using the local random variable
        theta_i = compute_something(z_i)
        pyro.observe("obs_{}".format(i), dist.mydist,
                     data[i], theta_i)
``````

But I am still confused about what I should do here. Should I have a nested for-loop inside the main one in which I sample the mu_k independently? To "observe" each data point given mu_k and z_i, how do I use the normal distribution to give a likelihood?
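One way to see how the pieces fit is to write the log joint by hand: the K means get their own prior terms outside the per-point loop, and the likelihood of x_i simply reads off the mean selected by z_i. A plain-Python sketch (not Pyro; the default prior values and the function names are assumptions):

``````python
import math


def normal_logpdf(x, loc, scale):
    # log N(x | loc, scale^2)
    return -0.5 * math.log(2 * math.pi * scale ** 2) - (x - loc) ** 2 / (2 * scale ** 2)


def dirichlet_logpdf(pi, alpha):
    # log Dir(pi | alpha) = log Gamma(sum a) - sum log Gamma(a_k) + sum (a_k - 1) log pi_k
    norm = math.lgamma(sum(alpha)) - sum(math.lgamma(a) for a in alpha)
    return norm + sum((a - 1) * math.log(p) for a, p in zip(alpha, pi))


def log_joint(x, z, mus, pi, alpha=None, prior_mu=0.0, prior_sigma=1.0, sigma=1.0):
    if alpha is None:
        alpha = [1.0] * len(pi)       # symmetric Dirichlet(1, ..., 1) by default
    lp = dirichlet_logpdf(pi, alpha)  # global: mixing proportions
    lp += sum(normal_logpdf(m, prior_mu, prior_sigma) for m in mus)  # global: one term per mean
    for x_i, z_i in zip(x, z):        # local: one pair of terms per data point
        lp += math.log(pi[z_i])                    # log Cat(z_i | pi)
        lp += normal_logpdf(x_i, mus[z_i], sigma)  # likelihood uses the mean picked by z_i
    return lp
``````

So no nested loop over k is needed inside the data loop: the means are drawn once, and `mus[z_i]` is the only coupling between the two ranges.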

Here is what I have:

``````python
def model(data):
    alpha0 = Variable(torch.Tensor([1.0,1.0,1.0]))
    prior_mu = Variable(torch.zeros([1, 1]))
    prior_sigma = Variable(torch.ones([1, 1]))

    for i in range(len(data)):
        zn = pyro.observe("latent_proportions", dist.categorical,pi,to_one_hot(y[i].data,3))
        mu = pyro.sample("latent_locations", dist.normal,prior_mu,prior_sigma )
        # observe data given cluster
        for k in range(k):
            pyro.sample("latent_locations", dist.normal,prior_mu,prior_sigma )

        pyro.observe("obs_{}".format(i), dist.normal,data[i], mu,Variable(torch.ones([1, 1])))
``````

I don't really know where to go from here, since the last line does not make use of `latent_locations` (mu) in my formulation. What should I do in this case, when there are two latent variables, each with a different range in the summation?


Hi Vincent, thatās an interesting model. One think to keep in mind is that all `pyro.sample()` sites need distinct names, whereas in your for loop, the same name is used for all passes through the loop.

I think I would try to lazily sample the cluster parameters, something like this (still needs work!):

``````python
import torch
from torch.autograd import Variable
import pyro
import pyro.distributions as dist


def model(data):
    alpha0 = 1.0
    prior_mu = Variable(torch.zeros([1, 1]))
    prior_sigma = Variable(torch.ones([1, 1]))
    sigma = Variable(torch.ones([1, 1]))

    cluster_means = {}  # sample this lazily
    crp_counts = []  # build this incrementally
    for i in range(len(data)):
        # sample from a CRP
        crp_weights = Variable(torch.Tensor(crp_counts + [alpha0]))
        crp_weights /= crp_weights.sum()
        zi = pyro.sample("z_{}".format(i), dist.categorical, crp_weights)
        zi = zi.data[0]  # this should be an int, not sure I've done it right
        if zi >= len(crp_counts):  # index len(crp_counts) is the new-table slot
            crp_counts.append(1)  # sit at a new table
        else:
            crp_counts[zi] += 1  # sit at an existing table

        # lazily sample cluster mean
        if zi not in cluster_means:
            cluster_means[zi] = pyro.sample("mu_{}".format(zi), dist.normal, prior_mu, prior_sigma)
        mui = cluster_means[zi]
        pyro.observe("obs_{}".format(i), dist.normal, data[i], mui, sigma)
``````

If you get this working, we should add a Dirichlet Process Mixture Model tutorial to Pyro.
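To make the seating logic concrete, here is a plain-Python version of the same lazy CRP construction with no Pyro calls. A sampled index equal to `len(counts)` is the new-table slot, so tables are numbered 0, 1, 2, ... in order of first use. The concentration, prior and noise values and the function name are assumptions.

``````python
import random


def crp_mixture(n, alpha=1.0, prior_mu=0.0, prior_sigma=1.0, sigma=0.1, seed=0):
    rng = random.Random(seed)
    counts = []          # counts[k] = number of customers at table k
    cluster_means = {}   # table index -> mean, drawn the first time the table is used
    z, x = [], []
    for i in range(n):
        # existing table k has weight counts[k]; the new table has weight alpha
        weights = counts + [alpha]
        k = rng.choices(range(len(weights)), weights=weights)[0]
        if k == len(counts):           # index one past the end means "new table"
            counts.append(1)
        else:
            counts[k] += 1
        if k not in cluster_means:     # lazily draw the mean for a new table
            cluster_means[k] = rng.gauss(prior_mu, prior_sigma)
        z.append(k)
        x.append(rng.gauss(cluster_means[k], sigma))
    return z, x, counts, cluster_means
``````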


Hi Fritz,

Thank you so much for the helpful message! I have been very busy with another project, so sorry for the late reply.

I have some follow up questions:

During training, the labels of the clusters are observed, but for testing, they aren't. So in my `guide` and `model` functions, should I include an `if` statement to handle this difference, like in the Semi-Supervised VAE tutorial?

Secondly, since I am using the mean-field approximation, i.e. my `q` depends on both global and local variables, should there be a `for` loop in the guide like the one in the model? The tutorial on conditional independence does this, but others don't. Since my model is adding clusters on the fly, how should we represent this in the guide?

My guide is expressed in the following form.

I agree that we should add a tutorial on Dirichlet process clustering since Pyro will be very suitable for this!

I will be working on this until I figure this out. I will post the part of my code I am confused about along with the suggested changes you mentioned later.

Vincent

Hi Fritz,

So I have been working on my guide function, and I have questions that I want to ask you.

First, to get the inferred locations of the cluster means and the assignment of each data point, should there be a `pyro.param("assignment_{}".format(i))` in the for loop of the guide (meaning we declare a new parameter for each data point)? I don't have a parameter for the cluster assignment right now. That way, at the end of training, I could print out the inferred variational parameters to know the locations.

Secondly, I am declaring a `pyro.param("weights_{}".format(i), crp_weights_q)` for each point so that I can `pyro.sample` it:

``````python
crp_weights_q = pyro.param("weights_{}".format(i), crp_weights_q)
zi_q = pyro.sample("z_{}".format(i), dist.categorical, crp_weights_q)
``````

Is this the right way to do it if my `"z_{}".format(i)` site is declared the way you suggested in your previous post? Since `crp_weights_q` is changing all the time, is this the correct way of telling Pyro that there is a parameter for each data point?

Thirdly, I am running a basic inference function like this:

``````python
import pyro
from pyro.infer import SVI
from pyro.optim import Adam

pyro.clear_param_store()
# set up the optimizer
adam_params = {"lr": 0.0005, "betas": (0.90, 0.999)}
optimizer = Adam(adam_params)

# set up the inference algorithm
svi = SVI(model, guide, optimizer, loss="ELBO", num_particles=7)

n_steps = 4000
# x is the data and y is the label
for step in range(n_steps):
    svi.step(x, y)
    if step % 100 == 0:
        print('.', end='')
``````

My data has 1000 points and labels, and I get the following error:

``````
site z_0 is observed and should not be overwritten
``````

after the 1000 points are processed in the `guide` function. It seems like Pyro wants to process the data again, but since the sample statement already uses the same name, it complains. However, I don't see any examples that clarify this issue. Pyro should process the 1000 data points on each of the `n_steps = 4000` steps, right?

Surprisingly, if I do not `print` `z_i` at each iteration (over the data) in the `guide` function, SVI returns this error immediately, i.e. in the first step. But if I do print `z_i`, I only get this error at the end of my 1000 points, and I cannot figure out what causes this strange behavior. It seems a print statement in the `guide` function changes when the error appears.

Here is my code for the guide. It is based on the `model` function you provided, with some modifications to the `categorical` distribution: I had to change it because it returns a one-hot encoded vector instead of an integer that I could use to index into `cluster_means_q`. I did the same in my model function as well. Please let me know if I did this correctly. Thank you!

``````python
import torch
from torch.autograd import Variable
import pyro
import pyro.distributions as dist


def guide(x, y):
    # locations of the clusters
    mu_q_0 = Variable(torch.zeros([1, 1]), requires_grad=True)
    sigma_q_0 = Variable(torch.ones([1, 1]), requires_grad=True)

    # register parameters
    mu_q = pyro.param("mu", mu_q_0)
    sigma_q = pyro.param("sigma", sigma_q_0)

    # prior on alpha, will register this later
    alpha_q = 4

    cluster_means_q = {}  # sample this lazily
    crp_counts_q = []  # build this incrementally

    for i in range(len(x)):
        # initial CRP weights: existing tables by count, a new table by alpha_q
        counts = torch.Tensor(crp_counts_q + [alpha_q])
        crp_weights_q = Variable(counts / counts.sum(), requires_grad=True)
        crp_weights_q = pyro.param("weights_{}".format(i), crp_weights_q)
        zi_q = pyro.sample("z_{}".format(i), dist.categorical, crp_weights_q)

        # the sample is one-hot; convert it to an integer table index
        zi_q = torch.nonzero(zi_q.data)
        zi_q = zi_q[0][0]

        if zi_q >= len(crp_counts_q):
            crp_counts_q.append(1)  # sit at a new table
        else:
            crp_counts_q[zi_q] += 1  # sit at an existing table
        if zi_q not in cluster_means_q:
            cluster_means_q[zi_q] = pyro.sample("mu_{}".format(zi_q), dist.normal, mu_q, sigma_q)
        # print('zi', zi_q)  # without this the program returns an error at i=1
``````

> should there be a `pyro.param("assignment_{}".format(i))` in the for loop of the guide (meaning we declare a new parameter for each data point)

Yes, which will learn individual assignments jointly with your weights.

> is this the correct way of telling pyro that there is a parameter for each data point

Yes, if I understand correctly, you want a sample per weight, so what you have above has the correct semantics.

> `site z_0 is observed and should not be overwritten`

This means that your sample and your observe statements have the same names. So on the next training iteration, you're reusing the same name, since the name is indexed by your inner for loop. Are your observes named `z_i`?


Hi,

For the third error (quoting jpchen):

> site z_0 is observed and should not be overwritten

For my `model` I have

``````python
zi = pyro.sample("z_{}".format(i), dist.categorical, crp_weights, obs=to_one_hot(y[i]), one_hot=False)
``````

The `obs=to_one_hot(y[i])` is the label I have for the training data. In the guide, there are no `obs` arguments or `pyro.observe` statements.

For my `guide` I have `zi_q = pyro.sample("z_{}".format(i), dist.categorical, crp_weights_q)`, which does not have an `obs` argument.

From what I understand, sample sites like `z_i` should be shared between the `model` and the `guide` functions, so what should I do when I have observations in the `model` function because I know the labels for the training data?

Yeah, there's your problem. Your `observe` statement should have no corresponding `sample` in the guide, since it is used in the likelihood term in SVI. However, all the other `sample` statements in your model should have corresponding `sample`s in your guide. Try renaming your observe to, e.g., what fritz suggested above:
`zi = pyro.sample("obs_{}".format(i), ... , obs=...)`

Hi,
I am still a little confused. I have to sample `zi` and `zi_q` from a `categorical` distribution in both my model and my guide. These `zi`/`zi_q` are the cluster assignments for each point (i.e. they are the "y" labels). It is just that when I know the labels, I can add the `obs=...` argument to the `pyro.sample` statement in the `model` function.

Now, since it is not allowed to `pyro.observe` something and `pyro.sample` the same name in both the guide and the model, how would I incorporate the knowledge I have of the labels in the model, and sample the same thing in my guide?
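The rule from the previous reply (observed sites appear only in the model; every latent site in the model needs a matching site in the guide) can be combined with the pattern from the Semi-Supervised VAE tutorial: the model always creates `z_i`, passing `obs=` when the label is known, while the guide samples `z_i` only for unlabeled points. The helpers below are hypothetical plain-Python bookkeeping, not Pyro's API; they just list which site names each function would create and check the rule.

``````python
def model_sites(labels):
    # site name -> "observed" or "latent", for the sites the model would create
    sites = {}
    for i, y_i in enumerate(labels):
        # z_i is always a model site; it is observed when the label is known
        sites["z_{}".format(i)] = "observed" if y_i is not None else "latent"
        sites["obs_{}".format(i)] = "observed"  # the data point x_i
    return sites


def guide_sites(labels):
    # the guide samples z_i only where the model treats it as latent
    return {"z_{}".format(i) for i, y_i in enumerate(labels) if y_i is None}


def check_sites(model, guide):
    observed = {name for name, kind in model.items() if kind == "observed"}
    latent = {name for name, kind in model.items() if kind == "latent"}
    if guide & observed:
        raise ValueError("guide samples observed sites: {}".format(sorted(guide & observed)))
    if guide != latent:
        raise ValueError("model latents and guide sites differ: {}".format(sorted(guide ^ latent)))
``````

In Pyro terms this corresponds to `pyro.sample("z_{}".format(i), ..., obs=label)` in the model and an `if label is None:` branch around the matching `pyro.sample` in the guide.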

Regarding your previous comment, my `model` function does have an

``````python
pyro.observe("obs_{}".format(i), dist.normal, data[i], mui, sigma)
``````

but this is for the value of the data points (i.e. the X, not the label), which is not related to `zi`. If I change the code to `zi = pyro.sample("obs_{}".format(i), ... , obs=...)` in my model function, then I cannot sample `zi` in my guide anymore, since it complains that

`Found vars in guide but not model: {}".format(guide_vars - model_vars`

I have also been getting the following error, and I do not know what causes it. It happens when I sample from the `Categorical` distribution with `one_hot=False`.

``````
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
/Users/zhitingchen/anaconda/lib/python3.6/site-packages/pyro/poutine/trace.py in log_pdf(self, site_filter)
66                 try:
---> 67                     site_log_p = site["log_pdf"]
68                 except KeyError:

KeyError: 'log_pdf'

During handling of the above exception, another exception occurred:

RuntimeError                              Traceback (most recent call last)
<ipython-input-30-03a21f270021> in <module>()
11 for step in range(n_steps):
---> 12     svi.step(x,y)
13     if step % 100 == 0:
14         print('.', end='')

/Users/zhitingchen/anaconda/lib/python3.6/site-packages/pyro/infer/svi.py in step(self, *args, **kwargs)
96         """
97         # get loss and compute gradients
---> 98         loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
99
100         # get active params

/Users/zhitingchen/anaconda/lib/python3.6/site-packages/pyro/infer/elbo.py in loss_and_grads(self, model, guide, *args, **kwargs)
63         :rtype: float
64         """
---> 65         return self.which_elbo.loss_and_grads(model, guide, *args, **kwargs)

/Users/zhitingchen/anaconda/lib/python3.6/site-packages/pyro/infer/trace_elbo.py in loss_and_grads(self, model, guide, *args, **kwargs)
131         elbo = 0.0
132         # grab a trace from the generator
--> 133         for weight, model_trace, guide_trace, log_r in self._get_traces(model, guide, *args, **kwargs):
134             elbo_particle = weight * 0
135             surrogate_elbo_particle = weight * 0

/Users/zhitingchen/anaconda/lib/python3.6/site-packages/pyro/infer/trace_elbo.py in _get_traces(self, model, guide, *args, **kwargs)
85             model_trace = prune_subsample_sites(model_trace)
86
---> 87             log_r = model_trace.log_pdf() - guide_trace.log_pdf()
88             weight = 1.0 / self.num_particles
89             yield weight, model_trace, guide_trace, log_r

/Users/zhitingchen/anaconda/lib/python3.6/site-packages/pyro/poutine/trace.py in log_pdf(self, site_filter)
69                     args, kwargs = site["args"], site["kwargs"]
70                     site_log_p = site["fn"].log_pdf(
---> 71                         site["value"], *args, **kwargs) * site["scale"]
72                     site["log_pdf"] = site_log_p
73                 log_p += site_log_p

/Users/zhitingchen/anaconda/lib/python3.6/site-packages/pyro/distributions/random_primitive.py in log_pdf(self, x, *args, **kwargs)
40
41     def log_pdf(self, x, *args, **kwargs):
---> 42         return self.dist_class(*args, **kwargs).log_pdf(x)
43
44     def batch_log_pdf(self, x, *args, **kwargs):

/Users/zhitingchen/anaconda/lib/python3.6/site-packages/pyro/distributions/distribution.py in log_pdf(self, x, *args, **kwargs)
184         """
186
187     @abstractmethod

/Users/zhitingchen/anaconda/lib/python3.6/site-packages/pyro/distributions/categorical.py in batch_log_pdf(self, x)
161             else:
--> 162                 boolean_mask = torch_zeros_like(logits.data).scatter_(-1, x.data.long(), 1)

RuntimeError: Invalid index in scatter at /Users/soumith/miniconda2/conda-bld/pytorch_1502000975045/work/torch/lib/TH/generic/THTensorMath.c:515
``````

Could this be related to this?

When I have `one_hot=True`, since the samples from the model and guide functions might be one-hot vectors of different sizes, it complains about size mismatches like the following:

``````
Model and guide dims disagree at site 'z_9': torch.Size([8]) vs torch.Size([9])
``````

This happens because the number of clusters in the categorical distribution keeps changing in the Dirichlet process. Do you have any ideas on how to fix this?
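One workaround, essentially the fixed-size approach that comes up later in this thread, is to choose a truncation level K larger than the number of clusters you expect and always hand the Categorical a length-K vector, with zero probability for tables that do not exist yet. The model and guide sites then have the same size at every step. A plain-Python sketch of the padded CRP probabilities (K and the function name are assumptions):

``````python
def padded_crp_probs(counts, alpha, K):
    """CRP seating probabilities written into a fixed-length vector of size K."""
    if len(counts) >= K:
        raise ValueError("truncation level K is too small for the current tables")
    n = sum(counts)
    probs = [c / (n + alpha) for c in counts]  # existing tables
    probs.append(alpha / (n + alpha))          # the next (new) table
    probs += [0.0] * (K - len(probs))          # slots for tables that do not exist yet
    return probs
``````

For example, `padded_crp_probs([2, 1], 1.0, 5)` gives `[0.5, 0.25, 0.25, 0.0, 0.0]`. The guide then needs its own length-K probabilities, and it should not put mass on slots the model gives probability 0.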

@vincent Here are some of my suggestions. Hope they help.

• Forget about the `guide` for now. Parameterize your cluster mean distribution by some parameters (similar to what you did for the guide, using `pyro.param`).

• After you sample the cluster means, sample a distribution P over these means using a Dirichlet distribution.

• Now, go to the for loop: using P as the probs of a Categorical distribution, draw a sample from it. This sample gives you the corresponding cluster mean from step 2. Then sample your data point from this cluster mean (using a Normal distribution). Because this step is independent across data points, you can use `irange`.

• Try to run inference on the above model using SVI with an empty guide (which means you are doing maximum likelihood). I would suggest making this work before defining any guide for your model. It is easy to get things wrong, so building the `model` should be the first step.

• For the Chinese restaurant process, you can consult @fritzo's suggestion.

• It would be easier to diagnose your problem if you used the latest version of Pyro (we now distinguish the Categorical and OneHotCategorical distributions, and have removed random primitives like `dist.normal`, … in favor of `dist.Normal(mu, sigma)`).
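For the finite model in these bullets, the quantity a maximum-likelihood fit ends up maximizing over the means and mixture weights is the marginal log-likelihood, in which each discrete assignment is summed out. A plain-Python sketch of that objective (unit observation noise and strictly positive weights assumed; names hypothetical):

``````python
import math


def log_normal(x, loc, scale):
    # log N(x | loc, scale^2)
    return -0.5 * math.log(2 * math.pi * scale ** 2) - (x - loc) ** 2 / (2 * scale ** 2)


def logsumexp(values):
    m = max(values)  # subtract the max for numerical stability
    return m + math.log(sum(math.exp(v - m) for v in values))


def marginal_log_likelihood(x, mus, probs, sigma=1.0):
    # log p(x | mus, probs) = sum_i log sum_k probs[k] * N(x_i | mus[k], sigma^2)
    return sum(
        logsumexp([math.log(p) + log_normal(x_i, m, sigma) for p, m in zip(probs, mus)])
        for x_i in x
    )
``````

Means placed on the data score higher than means placed between the clusters, which is what the optimizer exploits.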

Thanks for the suggestions.

I think my biggest problem now is that the `zi` (samples from my categorical distribution) have different sizes in my guide and my model. This happens because a Chinese restaurant process has no predetermined number of clusters, i.e. the P for the categorical keeps changing as more data points are observed. This causes a mismatch between the model and the guide because I use the same name:

``````python
pyro.sample("z_{}".format(i)....)
``````

for both the model and the guide.

Regarding your suggestion of using maximum likelihood, are you referring to the `gmm.ipynb` notebook in pyro-ppl/pyro on GitHub (commit 1bbfb48aa280dab36ea8b1ff17826665e1fd9eb6)?

I didn't know you could have an empty guide function. I have a `pass` statement in my guide, and that results in warnings about vars found in the model but not in the guide, because the samples in the model are not shared with the guide.

Thanks

Hi @vincent, sorry if I caused any confusion.

By an empty guide, I actually meant MAP inference (>"<).

If you have a `sample` statement like this:
`a = pyro.sample("a", dist.Normal(ng_zeros(1), ng_ones(1)))`,
then in the `guide`, change it to `a_map = pyro.param("a_map", Variable(torch.zeros(1), requires_grad=True))` and `a = pyro.sample("a", dist.Delta(a_map))`. Do this for all `sample` statements except the ones with `obs=...` arguments. (This way, we don't have to think about which guide is suitable for our model.)

For a Chinese restaurant process mixture, I think we can construct a model as follows:
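Why this amounts to MAP: for a continuous site, a `Delta` guide puts all its mass on the point `a_map`, so up to a constant the ELBO reduces to log p(x, a_map), and optimizing it is gradient ascent on the log joint. A minimal plain-Python illustration, assuming a single mean with a N(0, 1) prior and unit-variance Gaussian observations, where the MAP estimate has the closed form sum(x) / (n + 1):

``````python
def map_estimate(x, lr=0.01, steps=2000):
    # prior: mu ~ N(0, 1); likelihood: x_i ~ N(mu, 1)
    mu = 0.0  # plays the role of the learned point "a_map"
    for _ in range(steps):
        grad = -mu + sum(x_i - mu for x_i in x)  # d/dmu [log p(mu) + sum_i log p(x_i | mu)]
        mu += lr * grad                          # gradient ascent step
    return mu
``````

With `x = [1.0, 2.0, 3.0]` this converges to 6 / 4 = 1.5. This plain update needs `lr * (n + 1) < 2` to converge.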

``````python
import torch
from torch.autograd import Variable
import pyro
import pyro.distributions as dist
from pyro.util import ng_zeros, ng_ones

alpha = 1.0  # CRP concentration parameter (value assumed)


def model(x):
    n = x.size(0)
    mu_list = []
    num_customers = []  # number of customers at each table

    for i in range(n):
        mu_i = pyro.sample("mu_{}".format(i), dist.Normal(ng_zeros(1), ng_ones(1)))
        mu_list.append(mu_i)

        if i == 0:
            z_i = 0  # first customer always sits at table 0
        else:
            probs = Variable(torch.Tensor([c/(i+alpha) for c in (num_customers + [alpha])]))
            z_i = pyro.sample("z_{}".format(i), dist.Categorical(probs)).data[0]  # sample which table the new customer will sit at

        # one new slot per customer, so z_i always has exactly i+1 possible values
        num_customers.append(0)
        num_customers[z_i] += 1  # add 1 to that table
        pyro.sample("x_{}".format(i), dist.Normal(mu_list[z_i], ng_ones(1)), obs=x[i])
``````

Note that both `model` and `guide` in SVI use the same `x` at each step: `svi.step(x)`.
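The `probs` line in this model is the CRP predictive rule. With customers indexed from 0, customer $i$ finds $i$ customers already seated, $n_k$ of them at table $k$, so

$$
P(z_i = k \mid z_{0:i-1}) = \frac{n_k}{i + \alpha}, \qquad P(z_i = \text{new table} \mid z_{0:i-1}) = \frac{\alpha}{i + \alpha},
$$

and the probabilities sum to one because $\sum_k n_k = i$. For example, with $\alpha = 1$ and counts $(2, 1)$ at $i = 3$, the probabilities are $(2/4, 1/4, 1/4)$. Because `num_customers` grows by one slot per customer, $z_i$ always has exactly $i + 1$ possible values (zero-count slots get probability 0), which lets a guide use a fixed-size probability vector for each $z_i$.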

I currently have something similar to your formulation above. However, in order to plot the inferred means at the end, I need to store them in a `pyro.param("means", ...)` statement. Since the number of means keeps growing, how can I do this?

I am currently initializing the means vector with a fixed number of means (more than I need, so we won't overflow), and only index into this vector with `z_i` as you do. I plan to plot only the top three means as ranked by the proportions stored in the `probs` vector, which are the mixture proportions. Ideally, the real clusters will have high weights and the rest will go towards zero.

Your formulation uses a list to store the means, so how do you output the result of your inference if you don't register it as a param?
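The ranking step of that plan could look like the following plain-Python sketch; after training, `means` and `weights` would be read from the learned parameters (the helper name and the example values are hypothetical):

``````python
def top_means(means, weights, top=3):
    """Return the `top` (mean, weight) pairs with the largest mixture weights."""
    if len(means) != len(weights):
        raise ValueError("means and weights must have the same length")
    ranked = sorted(zip(means, weights), key=lambda mw: mw[1], reverse=True)
    return ranked[:top]
``````

For example, with five candidate means of which two received almost no weight, `top_means([-2.0, 0.0, 3.0, 5.0, 7.0], [0.3, 0.01, 0.4, 0.28, 0.01])` keeps `(3.0, 0.4)`, `(-2.0, 0.3)` and `(5.0, 0.28)`.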

Thank you

@vincent if you use MAP, then I guess `pyro.param("mu_{}_map".format(i))` is what you are looking for.


Hi @fehiepsi,

Before, we had to sample `zi` from a categorical like this:

``````python
z = pyro.sample("z_{}".format(i), dist.categorical, (crp_weights_q))
``````

Then we used `zi` to index into another list or Tensor containing the means, like this:

``````python
mu_z = mu.index_select(0, z)
pyro.sample('x', Normal(mu_z, sigma.expand_as(mu_z)), obs=data)
``````

Your MAP technique is a convenient way of decoupling the sampling process in the guide from any structure there may be in the model. This is why you said this technique works for any guide, right?

``````python
mu_map = pyro.param("mu_map_{}".format(z), Variable(torch.zeros(1), requires_grad=True))
mu = pyro.sample("mu_{}".format(z), dist.Delta(mu_map))
``````

However, since the MAP technique does not use samples from the categorical to index into anything, this means we won't have any `zi` in the guide. Can you recommend a MAP technique for `zi` as well?

My attempt was the following:

``````python
zi_map = softplus(pyro.param("z_map_{}".format(i), Variable(torch.Tensor([1]), requires_grad=True)))
zi_q = pyro.sample("z_{}".format(i), dist.categorical, (zi_map))
pyro.param("z_{}_map".format(i),zi_map)
``````

I have tried two ways of solving this:

Option 1 in the code makes sense to me, because I need the crp_weights to see which cluster assignments have small weights so I can filter them out. However, it does not run; it fails with "Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time."

Option 2 runs but puts everything into one cluster, i.e. the zi are always 1.

It might be easier to look at the guide I have:

``````python
def guide(x,y):
    # prior on alpha, will register this later
    alpha_q = (2)

    cluster_means_q = {}  # sample this lazily
    crp_counts_q = []  # build this incrementally

    for i in pyro.irange("data_loop", len(x)):

        #option 1
        zi_map = pyro.param("z_map",crp_weights_q)

        #option 2

        zi_q = pyro.sample("z_{}".format(i), dist.categorical, (zi_map))
        #convert into an int to index
        z_ind = ((zi_q.data[0]))

        if z_ind >= len(crp_counts_q):
            crp_counts_q.append(1)  # sit at a new table
        else:
            crp_counts_q[z_ind] += 1  # sit at an existing table

        if z_ind not in cluster_means_q: #is this if statement necessary
            mu = pyro.sample("mu_{}".format(z_ind), dist.Delta(mu_map))
            cluster_means_q[z_ind] = mu
``````

Thank you so much for the help. I donāt want to be spamming the forum, so please let me know if I am.

Hi @vincent, the discussion is helpful for other people too, so donāt have to worry about it. About MAP, it is just a starting step which helps get rid of thinking about `guide`. For example, the corresponding `guide` for my last model is:

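``````python
import torch
from torch.autograd import Variable
import pyro
import pyro.distributions as dist

alpha = 1.0  # CRP concentration parameter (value assumed; must match the model)


def MAP_guide(x):
    n = x.size(0)
    mu_list = []
    num_customers = []  # number of customers at each table

    for i in range(n):
        # point estimate for mu_i, learned as a parameter
        mu_i_MAP = pyro.param("mu_{}_MAP".format(i), Variable(torch.zeros(1), requires_grad=True))
        mu_i = pyro.sample("mu_{}".format(i), dist.Delta(mu_i_MAP))
        mu_list.append(mu_i)

        if i == 0:
            z_i = 0  # first customer always sits at table 0
        else:
            probs = Variable(torch.Tensor([c/(i+alpha) for c in (num_customers + [alpha])]))
            z_i = pyro.sample("z_{}".format(i), dist.Categorical(probs)).data[0]  # sample which table the new customer will sit at

        num_customers.append(0)
        num_customers[z_i] += 1  # add 1 to that table
``````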

MAP does not work with discrete distributions like Categorical. If you want to get information about `z`, you can define a simple guide for it. For example

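``````python
import torch
from torch.autograd import Variable
from torch.nn.functional import softmax
import pyro
import pyro.distributions as dist


def guide(x):
    n = x.size(0)

    for i in range(n):
        mu_i_MAP = pyro.param("mu_{}_MAP".format(i), Variable(torch.zeros(1), requires_grad=True))
        mu_i = pyro.sample("mu_{}".format(i), dist.Delta(mu_i_MAP))

        if i == 0:
            z_i = 0  # first customer always sits at table 0
        else:
            # one unconstrained score per possible table 0..i, mapped to probabilities
            z_i_probs_unconstrained = pyro.param("z_{}_probs_unconstrained".format(i),
                                                 Variable(torch.zeros(i + 1), requires_grad=True))
            z_i_probs = softmax(z_i_probs_unconstrained, dim=0)
            z_i = pyro.sample("z_{}".format(i), dist.Categorical(z_i_probs)).data[0]
``````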

Now, the params `mu_{}_MAP` and `z_{}_probs_unconstrained` will be learned to best describe the posterior distribution. You can also define a more complicated guide that takes information from `x`. Please let me know if it works.
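The `softmax` step maps each point's unconstrained scores to valid Categorical probabilities. Since the model gives $z_i$ exactly $i + 1$ possible tables, each `z_{}_probs_unconstrained` vector needs $i + 1$ entries. A plain-Python sketch of that mapping and of a zero initialization (names hypothetical):

``````python
import math


def softmax(scores):
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def initial_guide_scores(n):
    # hypothetical initialization: point i (i >= 1) gets i + 1 zero scores, one per possible table
    return {i: [0.0] * (i + 1) for i in range(1, n)}
``````

Zero scores start every $z_i$ at a uniform distribution over its $i + 1$ tables, and SVI then moves the scores away from uniform.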


@fehiepsi

I didnāt want to make a new topic since itās a similar question, what is the way to update the local variable when it depends on all observations, except the current one, so as in:

`P(z | z_{-k})`

This is modeled after the collapsed variational inference, where the weights are marginalized out: