Dirichlet Process Mixture Model Modification (GPU Code, Use of Particles, Discrete Enumeration)

Hi,

this is my first post here. First of all, thank you and your team for the great contribution with Pyro!

I modified the code in the DPMM tutorial as follows:

  • the code runs on GPU,
  • more than a single particle can be used during training (with num_particles=10),
  • TraceEnum_ELBO is employed instead of Trace_ELBO.

My questions are:

  1. The way I indexed the variables in the model's data plate is not elegant (and it does not work if num_particles=1). I am aware that more elegant ways exist with the use of vindex, but I could not come up with a more elegant solution myself. I would appreciate any suggestions here.

  2. I am aware that I cannot use mini-batching (PyTorch's DataLoader with shuffle=True) during each training iteration, since the z's are locally assigned to specific x instances. Subsampling the whole dataset using the plate's indices (with subsample_size) is not an option in my use case, since the whole dataset does not necessarily fit into GPU memory. Is there any solution/example/tutorial related to this issue?

  3. I occasionally observe mode-collapse issues with a bigger dataset when I just increase the num_particles parameter. I tried to prevent these issues with smart initialization within the guide,
    tau = pyro.param('tau', init_tensor=centers_init)
    and by reducing the variation around my initialization,
    q_mu = pyro.sample("mu", MultivariateNormal(tau, 0.1*torch.eye(subspace_dim, device=device)))
    Should I correspondingly make my prior more informative, e.g. within the model where mu is sampled (see the sketch after this list)? Any suggestions in this respect?
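
For reference, here is a minimal sketch of what I mean by a more informative prior. This is hypothetical: centers_init and subspace_dim come from my own setup and are just placeholders.

    # Hypothetical variant of the model's prior on mu: center it on precomputed
    # initial centers instead of zero, with a tighter covariance.
    with pyro.plate("mu_plate", T, device=device):
        mu = pyro.sample("mu", MultivariateNormal(centers_init,
                                                  0.5 * torch.eye(subspace_dim, device=device)))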

Code:

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
import torch.nn.functional as F
from torch.distributions import constraints

import pyro
from pyro.distributions import *
from pyro.infer import SVI, config_enumerate, TraceEnum_ELBO
from pyro.optim import Adam

assert pyro.__version__.startswith('1.8.6')
pyro.set_rng_seed(0)

device = torch.device("cuda")

data = torch.cat((MultivariateNormal(-8 * torch.ones(2), torch.eye(2)).sample([50]),
                  MultivariateNormal(8 * torch.ones(2), torch.eye(2)).sample([50]),
                  MultivariateNormal(torch.tensor([1.5, 2]), torch.eye(2)).sample([50]),
                  MultivariateNormal(torch.tensor([-0.5, 1]), torch.eye(2)).sample([50])))

data = data.to(device)
# plt.scatter(data[:, 0], data[:, 1])
# plt.title("Data Samples from Mixture of 4 Gaussians")
# plt.show()
N = data.shape[0]
num_particles = 10

########################################   
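# Stick-breaking: maps T-1 Beta draws to T mixture weights that sum to 1,
# e.g. mix_weights(torch.tensor([0.5, 0.5])) gives tensor([0.50, 0.25, 0.25]).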
def mix_weights(beta):
    beta1m_cumprod = (1 - beta).cumprod(-1)
    return F.pad(beta, (0, 1), value=1) * F.pad(beta1m_cumprod, (1, 0), value=1)

########################################   

def model(data):
    with pyro.plate("beta_plate", T-1, device=device):
        beta = pyro.sample("beta", Beta(1, alpha))

    with pyro.plate("mu_plate", T, device=device):
        mu = pyro.sample("mu", MultivariateNormal(torch.zeros(2, device=device), 5 * torch.eye(2, device=device)))

    with pyro.plate("data", N, device=device):
        #dim=-4 is an enumeration dim for z (e.g. T = 6 clusters)
        #dim=-3 is a particle vectorization (e.g. num_particles = 10 particles)
        #dim=-2 is allocated for "data" plate (1 value broadcasted over a batch)
        #dim=-1 is allocated as event dimension (2 values)
        z = pyro.sample("z", Categorical(mix_weights(beta).unsqueeze(-2)))
        pyro.sample("obs", MultivariateNormal(mu[torch.arange(num_particles).reshape(num_particles, 1), z, :],
                                              torch.eye(2, device=device)), obs=data)

########################################        
def guide(data):
    kappa = pyro.param('kappa', lambda: Uniform(torch.tensor(0., device=device), torch.tensor(2., device=device)).sample([T-1]),
                                        constraint=constraints.positive)
    tau = pyro.param('tau', lambda: MultivariateNormal(torch.zeros(2, device=device), 3 * torch.eye(2, device=device)).sample([T]))
    phi = pyro.param('phi', lambda: Dirichlet(1/T * torch.ones(T, device=device)).sample([N]), constraint=constraints.simplex)

    with pyro.plate("beta_plate", T-1, device=device):
        q_beta = pyro.sample("beta", Beta(torch.ones(T-1, device=device), kappa))

    with pyro.plate("mu_plate", T, device=device):
        q_mu = pyro.sample("mu", MultivariateNormal(tau, torch.eye(2, device=device)))

    with pyro.plate("data", N, device=device):
        z = pyro.sample("z", Categorical(phi))
        
        
T = 6  # truncation level of the stick-breaking approximation
optim = Adam({"lr": 0.05})
svi = SVI(model, config_enumerate(guide, 'parallel'), optim, loss=TraceEnum_ELBO(max_plate_nesting=1,
                                                                                 num_particles=num_particles,
                                                                                 vectorize_particles=True))
losses = []

def train(num_iterations):
    pyro.clear_param_store()
    for j in tqdm(range(num_iterations)):
        loss = svi.step(data)
        losses.append(loss)

def truncate(alpha, centers, weights):
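    # Heuristic pruning: keep components whose average assignment weight exceeds
    # a threshold proportional to 1/alpha, then renormalize the kept weights.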
    threshold = alpha**-1 / 100.
    true_centers = centers[weights > threshold]
    true_weights = weights[weights > threshold] / torch.sum(weights[weights > threshold])
    return true_centers, true_weights

alpha = torch.tensor([0.1], device=device)  # DP concentration parameter
train(1000)

# We make a point-estimate of our model parameters using the posterior means of tau and phi for the centers and weights
Bayes_Centers_01, Bayes_Weights_01 = truncate(alpha, pyro.param("tau").detach(), torch.mean(pyro.param("phi").detach(), dim=0))

plt.figure(figsize=(15, 5))
plt.subplot(1, 2, 1)
plt.scatter(data[:, 0].detach().cpu().numpy(), data[:, 1].detach().cpu().numpy(), color="blue")
plt.scatter(Bayes_Centers_01[:, 0].detach().cpu().numpy(), Bayes_Centers_01[:, 1].detach().cpu().numpy(), color="red")
plt.show()

I found a related post here.
I think that answers my second question (use of subsample instead of subsample_size).

I would appreciate any comments @fritzo and @ordabayev.

Hi @Nirvana

I would do it this way:

with pyro.plate("data", N, device=device) as idx:
        z = pyro.sample("z", Categorical(mix_weights(beta).unsqueeze(-2)))
        if num_particles > 1:
            pdx = torch.arange(num_particles, device=device)[:, None]
            mu_vindex = Vindex(mu)[pdx, z, :]
        else:
            mu_vindex = Vindex(mu)[z, :]
        pyro.sample("obs", MultivariateNormal(mu_vindex,
                                              torch.eye(2, device=device)), obs=data)

There is special logic for num_particles == 1, since then it behaves as if there were no num_particles dim.
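
For illustration, my reading of the resulting shapes (with max_plate_nesting=1 and vectorized particles; T and the trailing 2 as in the model above):

    # num_particles=10 -> mu.shape == (10, T, 2)  # leading particle dim
    # num_particles=1  -> mu.shape == (T, 2)      # no particle dim is allocated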

What do you mean by this? z’s are enumerated and are the same for all local instances.

You can do this:

    with pyro.plate("data", N, subsample_size=batch_size, device=device) as idx:
        ...
        pyro.sample("obs", MultivariateNormal(mu_vindex,
                                              torch.eye(2, device=device)), obs=data[idx])

    with pyro.plate("data", N, subsample_size=batch_size, device=device) as idx:
        z = pyro.sample("z", Categorical(phi[idx]))

That's right. You can pass indices from the dataloader directly to the plate's subsample argument instead of using subsample_size.

Hi @ordabayev,
thanks! That was the solution I was looking for.

I am trying to understand the Ellipsis notation to get rid of the if/else condition. E.g., instead of

        if num_particles > 1:
            pdx = torch.arange(num_particles, device=device)[:, None]
            mu_vindex = Vindex(mu)[pdx, z, :]
        else:
            mu_vindex = Vindex(mu)[z, :]

I tried to just use

mu_vindex = Vindex(mu)[..., z, :]

But this did not work (in the case num_particles > 1).

So it seems the additional dimension, when num_particles is other than 1, has to be handled with such an if/else check.

Hi @ordabayev

What do you mean by this? z’s are enumerated and are the same for all local instances.

According to my understanding, if I use DataLoader(shuffle=True) together with mini-batching, the input data will be re-shuffled and batched after each training epoch. Since the z random variables are local, even if the z's are enumerated (and the same), each ELBO gradient update is tied to a specific observation index x and its likelihood. Therefore the IndexedTensor dataset seems to be an appropriate solution, together with the subsample argument of the plate. When I use DataLoader(shuffle=False), I did not observe such an issue.

If I use the subsample_size argument, the whole dataset still has to be loaded into GPU memory first and is then indexed with idx to be subsampled (as far as I understand). Is my understanding correct?

I hope what I meant is clear now. Otherwise we can try to clarify this issue with working example code. I believe this is an important topic for practical applications of Pyro (for newbies like me :slightly_smiling_face:).

Any additional comments/remarks related to my third question (mode-collapse issues), @ordabayev?

I think I figured it out :slight_smile: :

In order to be able to use the ellipsis, the dims that you are indexing need to be event dims. Make the following changes:

def model(data):
    ...
    # no "mu_plate" here; T is folded into the event shape instead
    mu = pyro.sample(
        "mu",
        MultivariateNormal(
            torch.zeros(2, device=device), 5 * torch.eye(2, device=device)
        )
        .expand([T])
        .to_event(1),
    )
    ...
        mu_vindex = Vindex(mu)[..., z, :]  # no if/else

def guide(data):
    # no plate here, use to_event instead
    q_mu = pyro.sample(
        "mu",
        MultivariateNormal(tau, torch.eye(2, device=device)).to_event(1),
    )

That's right. If you use the DataLoader and have it return indices, then pass the indices to the model/guide and use them as the subsample keyword in the plate:

def model(data, idx):
    with pyro.plate("data", N, subsample=idx):
        pyro.sample("obs", ..., obs=data[idx])

This should work with both shuffle=True and shuffle=False.
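
For example, a minimal sketch of the training loop (batch_size and num_epochs are placeholders; model/guide are assumed to take idx as above):

    from torch.utils.data import DataLoader, TensorDataset

    # Iterate over indices only; data itself can stay on the CPU, and the
    # minibatch data[idx] can be moved to the GPU inside the model.
    index_loader = DataLoader(TensorDataset(torch.arange(N)),
                              batch_size=batch_size, shuffle=True)

    for epoch in range(num_epochs):
        for (idx,) in index_loader:
            loss = svi.step(data, idx.to(device))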


@martinjankowiak @fritzo any ideas?

Hi @ordabayev

I think I figured it out :slight_smile: :
In order to be able to use the ellipsis, the dims that you are indexing need to be event dims. Make the following changes: …

this is great, elegant, and quite a bit mind-bending at the same time :hugs:.
I needed to debug the vindex function in indexing.py to understand your solution. The description in NEP 21, which describes vectorized/inner indexing, was also very helpful.

Just a short summary of my understanding, to help others (and my future self): you practically reserve the batch dimensions only for the explicit/implicit plates ("data"/"num_particles") and move the previously reserved "mu_plate" batch dimension into the event dimensions, so that it can be indexed with the random variable z in addition to slice(None) (the colon) over the remaining event dims. In this way, [...] flexibly stands for the unknown batch dimension(s) reserved for plates, whether or not the outer implicit "num_particles" plate is present.
A (practical?) difference with respect to the previous (if/else) solution is that the random variables mu, previously independent within the "mu_plate" context, are now treated by Pyro as dependent random variables in event space (even though in reality they are T=6 independent cluster centers).
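
To make the shapes concrete, a small standalone check of my understanding (the shapes are my assumptions for max_plate_nesting=1 with vectorized particles):

    import torch
    from pyro.ops.indexing import Vindex

    P, T, D = 10, 6, 2                 # particles, clusters, data dim
    mu = torch.randn(P, 1, T, D)       # batch dims (P, 1) from plates; event dims (T, D)
    z = torch.randint(T, (T, 1, 1))    # enumerated assignment, enum dim at -3

    out = Vindex(mu)[..., z, :]        # the ellipsis absorbs the plate batch dims
    print(out.shape)                   # torch.Size([6, 10, 1, 2]) == (T, P, 1, D)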


I think the example in cell [8] here can be updated accordingly. It seems more elegant to use the ellipsis for batch dims instead of indexing them with vdx[:, None] as in that example. (The example was added in Update enumeration tutorial by fritzo · Pull Request #2892 · pyro-ppl/pyro · GitHub, following my suggestion in How to index into a batch dimension using an enumerated index? · Issue #2875 · pyro-ppl/pyro · GitHub.)

@Nirvana would you be willing to contribute an updated version of the example?

I would be glad to contribute :grin:.

@ordabayev (assuming that my understanding above is correct): what is your opinion on the potential practical impact, within Pyro, of handling the T=6 independent cluster centers by mapping them into event space and treating them as dependent? In the end, my question is how to decide between coding elegance and the model's adherence to reality / inference efficiency.
I believe that I am on the safe side by assuming the mu's are dependent even though they are in fact conditionally independent.

I agree that it is safe. I would also add that even though the mu's are independent, you cannot subsample mu_plate (there is a fixed total number of categories), which is a distinguishing feature of pyro.plate over to_event.
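
In code, the trade-off looks like this (a minimal sketch; loc and cov are placeholders):

    # T conditionally independent centers; the plate dim could in principle be subsampled:
    with pyro.plate("mu_plate", T):
        mu = pyro.sample("mu", MultivariateNormal(loc, cov))

    # Same joint density with T folded into the event shape; not subsamplable,
    # but Vindex(mu)[..., z, :] now works uniformly:
    mu = pyro.sample("mu", MultivariateNormal(loc, cov).expand([T]).to_event(1))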