LKJ Priors in Dirichlet Process Clustering

I am extending the Dirichlet Process mixture example here to support placing priors on the covariance matrices of the clusters.

Since no Wishart prior is implemented yet, I’ve gone with LKJ priors. However, I am struggling to achieve good clustering performance. Any feedback would be greatly appreciated.

Here is an example

Data generation:

import torch
from torch.distributions import Gamma

import torch.nn.functional as F

import matplotlib.pyplot as plt
from tqdm import tqdm

from pyro.distributions import *

import pyro
from pyro.optim import Adam
from pyro.infer import SVI, Trace_ELBO, Predictive

# This script targets the Pyro 1.x API.
assert pyro.__version__.startswith('1')
pyro.enable_validation(True)       # can help with debugging
pyro.set_rng_seed(0)

# Synthetic data: three 2-D Gaussian blobs of 50 points each, centred at (-2, -2),
# (2, 2) and (0, 0), all with isotropic covariance 0.1 * I. The sampling order below
# determines which random numbers each blob consumes.
data = torch.cat((MultivariateNormal(-2 * torch.ones(2), 0.1 * torch.eye(2)).sample([50]),
                  MultivariateNormal(2 * torch.ones(2), 0.1 * torch.eye(2)).sample([50]),
                  MultivariateNormal(torch.tensor([0., 0.]), 0.1 * torch.eye(2)).sample([50])))

# Number of observations (rows of `data`).
N = data.shape[0]

Model and guide:

def model(data, **kwargs):
    """Truncated stick-breaking DP Gaussian mixture with LKJ correlation priors.

    Reads the module-level globals ``T`` (truncation level), ``alpha`` (DP
    concentration) and ``N`` (number of observations), and the helper ``mix_weights``.

    :param data: observations of shape ``(N, 2)`` (one 2-D point per row).
    """
    # Stick-breaking proportions: T - 1 Beta(1, alpha) draws, turned into T weights below.
    with pyro.plate("beta_plate", T - 1):
        beta = pyro.sample("beta", Beta(1, alpha))

    # Gamma(concentration=1, rate=2) prior on each cluster's per-dimension precisions.
    zeta = 1. * torch.ones(T, 2)
    delta = 2. * torch.ones(T, 2)
    with pyro.plate("prec_plate", T):
        prec = pyro.sample("prec", Gamma(zeta, delta).to_event(1))

    # One LKJ(eta=1) sample per cluster for the Cholesky factor of its correlation
    # matrix, drawn in a Python loop under distinct site names corr_chol_0 .. corr_chol_{T-1}.
    corr_chol = torch.zeros(T, 2, 2)
    for t in range(T):
        corr_chol[t, ...] = pyro.sample("corr_chol_{}".format(t), LKJCorrCholesky(d=2, eta=torch.ones(1,)))

    with pyro.plate("mu_plate", T):
        # Covariance Cholesky factor per cluster: diag(1 / sqrt(prec)) @ L_corr.
        _std = torch.sqrt(1. / prec)
        sigma_chol = torch.bmm(torch.diag_embed(_std), corr_chol)
        # Zero-centred prior on each cluster mean, using that cluster's own covariance.
        mu = pyro.sample("mu", MultivariateNormal(torch.zeros(T, 2), scale_tril=sigma_chol))

    with pyro.plate("data", N):
        # Cluster assignment of every point, then the Gaussian likelihood of that cluster.
        z = pyro.sample("z", Categorical(mix_weights(beta)))
        pyro.sample("obs", MultivariateNormal(mu[z], scale_tril=sigma_chol[z]), obs=data)


def guide(data, **kwargs):
    """Mean-field variational guide matching ``model`` site by site.

    Reads the module-level globals ``T`` (truncation level), ``alpha`` and ``N``.

    Each cluster has its own Gamma variational parameters for its precisions
    (``zeta``/``delta`` of shape ``(T, 2)``). A single shared ``(2,)`` pair would be
    broadcast over the ``prec_plate`` and force every cluster to have the same
    approximate posterior over its scale.
    """
    gamma = pyro.param('gamma', alpha * torch.ones(T - 1,), constraint=constraints.positive)

    # Per-cluster, per-dimension Gamma(concentration, rate) parameters.
    zeta = pyro.param('zeta', lambda: Uniform(1., 2.).sample([T, 2]), constraint=constraints.positive)
    delta = pyro.param('delta', lambda: Uniform(1., 2.).sample([T, 2]), constraint=constraints.positive)

    # LKJ concentration of each cluster's correlation Cholesky factor.
    psi = pyro.param('psi', lambda: Uniform(0.5, 1e0).sample([T]), constraint=constraints.positive)

    # Randomly initialised variational means of the cluster centres, and the soft
    # assignment probabilities of each point (rows on the simplex).
    tau = pyro.param('tau', lambda: MultivariateNormal(torch.zeros(2), 10 * torch.eye(2)).sample([T]))
    pi = pyro.param('pi', torch.ones(N, T) / T, constraint=constraints.simplex)

    with pyro.plate("beta_plate", T - 1):
        q_beta = pyro.sample("beta", Beta(torch.ones(T - 1), gamma))

    with pyro.plate("prec_plate", T):
        q_prec = pyro.sample("prec", Gamma(zeta, delta).to_event(1))

    q_corr_chol = torch.zeros(T, 2, 2)
    for t in range(T):
        q_corr_chol[t, ...] = pyro.sample("corr_chol_{}".format(t), LKJCorrCholesky(d=2, eta=psi[t]))

    with pyro.plate("mu_plate", T):
        # scale_tril = diag(1 / sqrt(prec)) @ L_corr for each cluster.
        _q_std = torch.sqrt(1. / q_prec)
        q_sigma_chol = torch.bmm(torch.diag_embed(_q_std), q_corr_chol)
        q_mu = pyro.sample("mu", MultivariateNormal(tau, scale_tril=q_sigma_chol))

    with pyro.plate("data", N):
        z = pyro.sample("z", Categorical(pi))

Run:

# Truncation level: number of mixture components represented by the stick-breaking model.
T = 3

optim = Adam({"lr": 0.001})
# Trace_ELBO averages the ELBO estimate over 3 guide samples (particles) per step.
svi = SVI(model, guide, optim, loss=Trace_ELBO(num_particles=3))


def train(num_iterations):
    """Run ``num_iterations`` SVI steps starting from an empty parameter store.

    Every 100 steps (including step 0) the current posterior summary from
    ``marginal`` is drawn on a live figure via ``animate``. Reads the module-level
    ``svi``, ``data`` and ``guide``.

    :returns: list of the loss returned by each ``svi.step`` call.
    """
    history = []
    pyro.clear_param_store()

    canvas = plt.figure(figsize=(5, 5))
    for step in tqdm(range(num_iterations)):
        history.append(svi.step(data))

        if step % 100 != 0:
            continue

        # Redraw the current clustering, show it briefly, then wipe the figure.
        centers, covars = marginal(guide, num_samples=500)
        animate(canvas.gca(), centers, covars)
        plt.draw()
        plt.axis('equal')
        plt.pause(0.001)
        plt.clf()

    return history


def truncate(alpha, centers, perc, corrs, weights):
    """Drop mixture components with negligible weight and build their covariances.

    The number of components and the data dimension are inferred from the inputs, so
    the function does not depend on any module-level globals.

    :param alpha: DP concentration; components with weight ``<= 1 / (100 * alpha)``
        are dropped.
    :param centers: component means, shape ``(T, D)``.
    :param perc: component precisions holding ``T * D`` values (e.g. ``(T, D)`` or
        flat ``(T * D,)``).
    :param corrs: Cholesky factors of the component correlation matrices, shape
        ``(T, D, D)``.
    :param weights: mixture weights, shape ``(T,)``.
    :returns: ``(centers, covariances, weights)`` of the kept components; the kept
        weights are renormalised to sum to one.
    """
    threshold = alpha**-1 / 100.
    keep = weights > threshold

    dim = centers.shape[-1]
    prec = perc.reshape(len(weights), dim)

    true_centers = centers[keep]

    # Covariance factor diag(std) @ L_corr, then Sigma = F @ F^T, batched over components
    # (stays on the inputs' device and dtype).
    stds = torch.sqrt(1. / prec[keep])
    factors = torch.diag_embed(stds) @ corrs[keep]
    true_sigmas = factors @ factors.transpose(-1, -2)

    kept_weights = weights[keep]
    true_weights = kept_weights / torch.sum(kept_weights)
    return true_centers, true_sigmas, true_weights


def marginal(guide, num_samples=25):
    """Summarise the variational posterior with Monte Carlo means of guide samples.

    Reads the module-level globals ``data``, ``T`` and ``alpha`` and calls the helpers
    ``mix_weights`` and ``truncate``.

    :param guide: the guide whose sample sites are averaged.
    :param num_samples: number of guide draws used for the averages.
    :returns: ``(centers, covariances)`` of the components kept by ``truncate``.
    """
    posterior_predictive = Predictive(guide, num_samples=num_samples)
    posterior_samples = posterior_predictive.forward(data)

    mu_mean = posterior_samples['mu'].detach().mean(dim=0)
    prec_mean = posterior_samples['prec'].detach().mean(dim=0)

    # Per-cluster mean of the sampled correlation Cholesky factors, stacked to (T, D, D)
    # on the samples' device.
    corr_mean = torch.stack([posterior_samples['corr_chol_{}'.format(t)].detach().mean(dim=0)
                             for t in range(T)])
    # Each sampled factor has unit-norm rows (so L @ L^T has a unit diagonal), but their
    # average generally does not. Renormalise the rows so the averaged factor is again a
    # valid correlation Cholesky factor and the plotted variances are not shrunk.
    corr_mean = corr_mean / corr_mean.norm(dim=-1, keepdim=True)

    beta_mean = posterior_samples['beta'].detach().mean(dim=0)
    weights_mean = mix_weights(beta_mean)

    centers, sigmas, _ = truncate(alpha, mu_mean, prec_mean, corr_mean, weights_mean)

    return centers, sigmas


def animate(axes, centers, covars):
    """Draw the data, the component centres and one-sigma covariance ellipses.

    Reads the module-level global ``data`` (its first two columns are plotted).

    :param axes: matplotlib axes to draw on.
    :param centers: component means, shape ``(K, 2)``.
    :param covars: component covariance matrices, shape ``(K, 2, 2)``.
    """
    # Draw the data on the supplied axes as well, instead of going through the pyplot
    # state machine (which targets whatever axes happen to be current).
    axes.scatter(data[:, 0], data[:, 1], color="blue", marker="+")

    from math import pi
    t = torch.arange(0, 2 * pi, 0.01)
    circle = torch.stack([torch.sin(t), torch.cos(t)], dim=0)

    axes.scatter(centers[:, 0], centers[:, 1], color="red")
    for center, covar in zip(centers, covars):
        # Mapping the unit circle through the Cholesky factor L gives the set
        # {x : x^T Sigma^{-1} x = 1}, i.e. the one-sigma contour.
        ellipse = torch.mm(torch.cholesky(covar), circle)
        axes.plot(ellipse[0, :] + center[0], ellipse[1, :] + center[1],
                  linestyle='-', linewidth=2, color='g', alpha=1.)


# DP concentration used by Beta(1, alpha) in the model; smaller values put more mass
# on the first sticks, i.e. favour fewer occupied components.
alpha = 0.1
# `elbo` holds the per-step losses returned by svi.step (negative ELBO estimates).
elbo = train(25000)

plt.figure()
plt.plot(elbo)

OK, so I had to tweak the model a couple of times. It was a great way to learn how to use Pyro.

The bottleneck is mainly the posterior sampling: I had to use a large number of samples (ELBO particles) to get a good estimate of the ELBO, and this is very slow and scales badly.

Unfortunately, I don’t think such models are very practical in Pyro yet; the amount of computation is unbelievably high. There may be a way to improve it by enumerating the discrete cluster assignments, but I couldn’t figure it out yet. I hope this will improve in the future.

Here is the final version, with some primitive animation (you can also find it in one piece here):

Data generation:

import torch
from torch.distributions import Gamma

import torch.nn.functional as F

import matplotlib.pyplot as plt
from tqdm import tqdm

from pyro.distributions import *

import pyro
from pyro.optim import Adam
from pyro.infer import SVI, Trace_ELBO, Predictive

# This script targets the Pyro 1.x API.
assert pyro.__version__.startswith('1')
pyro.enable_validation(True)
pyro.set_rng_seed(1337)

# Swap in the commented-out line to run on the first GPU when one is available.
# device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
device = torch.device("cpu")

# Synthetic data: three 2-D Gaussian blobs of 25 points each, centred at (-2, -2),
# (2, 2) and (0, 0), all with isotropic covariance 0.1 * I. Sampled on the CPU and
# then moved to `device`.
data = torch.cat((MultivariateNormal(-2 * torch.ones(2), 0.1 * torch.eye(2)).sample([25]),
                  MultivariateNormal(2 * torch.ones(2), 0.1 * torch.eye(2)).sample([25]),
                  MultivariateNormal(torch.tensor([0., 0.]), 0.1 * torch.eye(2)).sample([25])))

data = data.to(device)

# Number of observations and their dimension.
N = data.shape[0]
D = data.shape[1]

Model and guide:

def mix_weights(beta):
    """Turn stick-breaking proportions into mixture weights.

    For ``beta`` of shape ``(..., K)`` the result has shape ``(..., K + 1)``:
    ``w_k = beta_k * prod_{j<k} (1 - beta_j)`` for ``k < K``, and the last weight is
    the leftover stick ``prod_j (1 - beta_j)``.
    """
    leftover = torch.cumprod(1 - beta, dim=-1)
    # Append a 1 to the proportions so the last component takes the entire remaining stick...
    head = F.pad(beta, (0, 1), value=1)
    # ...and prepend a 1 to the running product so the first component sees the full stick.
    remaining = F.pad(leftover, (1, 0), value=1)
    return head * remaining


def model(data, **kwargs):
    """Truncated stick-breaking DP Gaussian mixture with LKJ correlation priors.

    Reads the module-level globals ``T`` (truncation level), ``D`` (data dimension),
    ``N`` (number of observations), ``alpha`` (DP concentration) and ``device``, and
    the helper ``mix_weights``.

    :param data: observations of shape ``(N, D)``.
    """
    # Stick-breaking proportions: T - 1 Beta(1, alpha) draws, turned into T weights below.
    with pyro.plate("beta_plate", T - 1):
        beta = pyro.sample("beta", Beta(1, alpha))

    # Gamma(concentration=2, rate=2) prior on all T * D precisions, kept flat here and
    # reshaped to (T, D) when building the covariance factors.
    zeta = 2. * torch.ones(T * D, device=device)
    delta = 2. * torch.ones(T * D, device=device)
    with pyro.plate("prec_plate", T * D):
        prec = pyro.sample("prec", Gamma(zeta, delta))

    # One LKJ(eta=1) sample per cluster for its correlation Cholesky factor. The plate
    # is iterated sequentially, so each cluster gets its own site corr_chol_{t}.
    corr_chol = torch.zeros(T, D, D, device=device)
    for t in pyro.plate("corr_chol_plate", T):
        corr_chol[t, ...] = pyro.sample("corr_chol_{}".format(t), LKJCorrCholesky(d=D, eta=torch.ones(1, device=device)))

    with pyro.plate("mu_plate", T):
        # Covariance Cholesky factor per cluster: diag(1 / sqrt(prec)) @ L_corr.
        _std = torch.sqrt(1. / prec.view(-1, D))
        sigma_chol = torch.bmm(torch.diag_embed(_std), corr_chol)
        # Zero-centred prior on each cluster mean, using that cluster's own covariance.
        mu = pyro.sample("mu", MultivariateNormal(torch.zeros(T, D, device=device), scale_tril=sigma_chol))

    with pyro.plate("data", N):
        # Cluster assignment of every point, then the Gaussian likelihood of that cluster.
        z = pyro.sample("z", Categorical(mix_weights(beta)))
        pyro.sample("obs", MultivariateNormal(mu[z], scale_tril=sigma_chol[z]), obs=data)


def guide(data, **kwargs):
    """Mean-field variational guide matching ``model`` site by site.

    Reads the module-level globals ``T`` (truncation level), ``D`` (data dimension),
    ``N`` (number of observations), ``alpha`` and ``device``.
    """
    gamma = pyro.param('gamma', alpha * torch.ones(T - 1, device=device), constraint=constraints.positive)

    # Gamma(concentration, rate) parameters for every (cluster, dimension) precision.
    zeta = pyro.param('zeta', lambda: Uniform(1., 2.).sample([T * D]).to(device),  constraint=constraints.positive)
    delta = pyro.param('delta', lambda: Uniform(1., 2.).sample([T * D]).to(device), constraint=constraints.positive)

    # LKJ concentration of each cluster's correlation Cholesky factor.
    psi = pyro.param('psi', lambda: Uniform(1., 2.).sample([T]).to(device), constraint=constraints.positive)

    # Random initial cluster centres. The covariance must be D x D to match the
    # D-dimensional zero mean; it was hard-coded to 2 x 2, which fails for D != 2.
    tau = pyro.param('tau', lambda: MultivariateNormal(torch.zeros(D), 10 * torch.eye(D)).sample([T]).to(device))
    # Soft assignment probabilities of each point (rows on the simplex).
    pi = pyro.param('pi', torch.ones(N, T, device=device) / T, constraint=constraints.simplex)

    with pyro.plate("beta_plate", T - 1):
        q_beta = pyro.sample("beta", Beta(torch.ones(T - 1, device=device), gamma))

    with pyro.plate("prec_plate", T * D):
        q_prec = pyro.sample("prec", Gamma(zeta, delta))

    q_corr_chol = torch.zeros(T, D, D, device=device)
    for t in pyro.plate("corr_chol_plate", T):
        q_corr_chol[t, ...] = pyro.sample("corr_chol_{}".format(t), LKJCorrCholesky(d=D, eta=psi[t]))

    with pyro.plate("mu_plate", T):
        # scale_tril = diag(1 / sqrt(prec)) @ L_corr for each cluster.
        _q_std = torch.sqrt(1. / q_prec.view(-1, D))
        q_sigma_chol = torch.bmm(torch.diag_embed(_q_std), q_corr_chol)
        q_mu = pyro.sample("mu", MultivariateNormal(tau, scale_tril=q_sigma_chol))

    with pyro.plate("data", N):
        z = pyro.sample("z", Categorical(pi))

Run:

# Truncation level: number of mixture components represented by the stick-breaking model.
T = 5

optim = Adam({"lr": 0.01})
# Trace_ELBO averages the ELBO estimate over 35 guide samples (particles) per step.
svi = SVI(model, guide, optim, loss=Trace_ELBO(num_particles=35))


def train(num_iterations, plot_every=None):
    """Run ``num_iterations`` SVI steps starting from an empty parameter store.

    Reads the module-level ``svi``, ``data`` and ``guide``.

    :param num_iterations: number of ``svi.step`` calls.
    :param plot_every: if a positive int, every ``plot_every`` steps (including step 0)
        the posterior summary from ``marginal`` is drawn on a live figure via
        ``animate``. ``None`` (default) disables plotting, which matches the behaviour
        when the plotting code was commented out.
    :returns: list of the loss returned by each ``svi.step`` call.
    """
    losses = []
    pyro.clear_param_store()

    # Only open a figure when plotting is requested.
    fig = plt.figure(figsize=(5, 5)) if plot_every else None

    for j in tqdm(range(num_iterations)):
        loss = svi.step(data)
        losses.append(loss)

        if plot_every and j % plot_every == 0:
            centers, covars = marginal(guide, num_samples=250)
            animate(fig.gca(), centers, covars)
            plt.draw()
            plt.axis('equal')
            plt.pause(0.001)
            plt.clf()

    return losses


def truncate(alpha, centers, perc, corrs, weights):
    """Drop mixture components with negligible weight and build their covariances.

    The number of components and the data dimension are inferred from the inputs, so
    the function does not depend on the module-level ``T``/``D``. The result stays on
    the inputs' device; the previous version filled a CPU-only buffer.

    :param alpha: DP concentration; components with weight ``<= 1 / (100 * alpha)``
        are dropped.
    :param centers: component means, shape ``(T, D)``.
    :param perc: component precisions holding ``T * D`` values (e.g. flat
        ``(T * D,)`` or ``(T, D)``).
    :param corrs: Cholesky factors of the component correlation matrices, shape
        ``(T, D, D)``.
    :param weights: mixture weights, shape ``(T,)``.
    :returns: ``(centers, covariances, weights)`` of the kept components; the kept
        weights are renormalised to sum to one.
    """
    threshold = alpha**-1 / 100.
    keep = weights > threshold

    dim = centers.shape[-1]
    prec = perc.reshape(len(weights), dim)

    true_centers = centers[keep]

    # Covariance factor diag(std) @ L_corr, then Sigma = F @ F^T, batched over components.
    stds = torch.sqrt(1. / prec[keep])
    factors = torch.diag_embed(stds) @ corrs[keep]
    true_sigmas = factors @ factors.transpose(-1, -2)

    kept_weights = weights[keep]
    true_weights = kept_weights / torch.sum(kept_weights)
    return true_centers, true_sigmas, true_weights


def marginal(guide, num_samples=25):
    """Summarise the variational posterior with Monte Carlo means of guide samples.

    Reads the module-level globals ``data``, ``T`` and ``alpha`` and calls the helpers
    ``mix_weights`` and ``truncate``.

    :param guide: the guide whose sample sites are averaged.
    :param num_samples: number of guide draws used for the averages.
    :returns: ``(centers, covariances)`` of the components kept by ``truncate``.
    """
    posterior_predictive = Predictive(guide, num_samples=num_samples)
    posterior_samples = posterior_predictive.forward(data)

    mu_mean = posterior_samples['mu'].detach().mean(dim=0)
    prec_mean = posterior_samples['prec'].detach().mean(dim=0)

    # Per-cluster mean of the sampled correlation Cholesky factors, stacked to (T, D, D).
    # Stacking keeps the result on the samples' device. A torch.zeros(T, D, D) buffer
    # always lives on the CPU and mismatches the other tensors when `device` is a GPU.
    corr_mean = torch.stack([posterior_samples['corr_chol_{}'.format(t)].detach().mean(dim=0)
                             for t in range(T)])
    # Each sampled factor has unit-norm rows (so L @ L^T has a unit diagonal), but their
    # average generally does not. Renormalise the rows so the averaged factor is again a
    # valid correlation Cholesky factor and the plotted variances are not shrunk.
    corr_mean = corr_mean / corr_mean.norm(dim=-1, keepdim=True)

    beta_mean = posterior_samples['beta'].detach().mean(dim=0)
    weights_mean = mix_weights(beta_mean)

    centers, sigmas, _ = truncate(alpha, mu_mean, prec_mean, corr_mean, weights_mean)

    return centers, sigmas


def animate(axes, centers, covars):
    """Draw the data, the component centres and one-sigma covariance ellipses.

    Reads the module-level global ``data`` (its first two columns are plotted).

    :param axes: matplotlib axes to draw on.
    :param centers: component means, shape ``(K, 2)``.
    :param covars: component covariance matrices, shape ``(K, 2, 2)``.
    """
    # Draw the data on the supplied axes as well, instead of going through the pyplot
    # state machine (which targets whatever axes happen to be current).
    axes.scatter(data[:, 0], data[:, 1], color="blue", marker="+")

    from math import pi
    t = torch.arange(0, 2 * pi, 0.01)
    circle = torch.stack([torch.sin(t), torch.cos(t)], dim=0)

    axes.scatter(centers[:, 0], centers[:, 1], color="red")
    for center, covar in zip(centers, covars):
        # Mapping the unit circle through the Cholesky factor L gives the set
        # {x : x^T Sigma^{-1} x = 1}, i.e. the one-sigma contour.
        ellipse = torch.mm(torch.cholesky(covar), circle)
        axes.plot(ellipse[0, :] + center[0], ellipse[1, :] + center[1],
                  linestyle='-', linewidth=2, color='g', alpha=1.)


# DP concentration, as a 1-element tensor on the same device as the rest of the model.
alpha = 0.1 * torch.ones(1, device=device)
# `elbo` holds the per-step losses returned by svi.step (negative ELBO estimates).
elbo = train(5000)

# plt.figure()
# plt.plot(elbo)
1 Like

@hanyas have you tried doing MAP with respect to corr_chol_{}? Sampling correlation matrices in the context of variational inference may not work very well. In particular, the LKJ distribution is not “reparameterizable”, so it necessarily leads to high-variance gradients (because the ELBO estimator uses so-called score-function gradients). You might get much more reasonable performance if you “demote” corr_chol_{} to a point estimate.

@martinjankowiak Thanks for the hint, although it kind of defeats the ultimate purpose I am pursuing.

I tried something like the following, but I keep getting invalid scale_tril values, which I guess is due to the Delta distribution I am defining.

Let me know if there is a proper way of defining the MAP. It is a bit under-documented.

# MAP (point) estimate of each cluster's correlation Cholesky factor, used in the guide
# in place of the LKJ variational distribution. Relies on the script's T, D and device.
auto_corr_chol = torch.zeros(T, D, D, device=device)
q_corr_chol = torch.zeros(T, D, D, device=device)
for t in pyro.plate("corr_chol_plate", T):
    # The param must be constrained to valid correlation Cholesky factors. Without the
    # constraint, an optimizer step can leave it non-lower-triangular or with a
    # non-positive diagonal, which then fails MultivariateNormal's scale_tril check.
    corr_chol_t = pyro.param("auto_corr_chol_{}".format(t), torch.eye(D, device=device),
                             constraint=constraints.corr_cholesky_constraint)
    auto_corr_chol[t, ...] = corr_chol_t
    # Delta over the whole D x D matrix: both trailing dims are event dims. Use the param
    # tensor directly rather than re-reading it from the in-place-written buffer.
    q_corr_chol[t, ...] = pyro.sample("corr_chol_{}".format(t), Delta(corr_chol_t).to_event(2))

@hanyas can you post your error? you may just be running into optimization issues (too high learning rate, etc)

I suspected an issue with the step size, so I changed it, but the error remained. Here I am using Adam with a learning rate of 0.001:

  0%|          | 1/5000 [00:00<39:54,  2.09it/s]
Traceback (most recent call last):
  File "/home/hany/.miniconda3/envs/hbreps/lib/python3.7/site-packages/pyro/poutine/trace_messenger.py", line 165, in __call__
ret = self.fn(*args, **kwargs)
  File "/home/hany/phd/repos/playground/misc/pyro/dpgmm.py", line 87, in guide
q_mu = pyro.sample("mu", MultivariateNormal(tau, scale_tril=q_sigma_chol))
  File "/home/hany/.miniconda3/envs/hbreps/lib/python3.7/site-packages/torch/distributions/multivariate_normal.py", line 144, in __init__
super(MultivariateNormal, self).__init__(batch_shape, event_shape, validate_args=validate_args)
  File "/home/hany/.miniconda3/envs/hbreps/lib/python3.7/site-packages/torch/distributions/distribution.py", line 36, in __init__
raise ValueError("The parameter {} has invalid values".format(param))
ValueError: The parameter scale_tril has invalid values
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
  File "<input>", line 1, in <module>
  File "/opt/pycharm-professional/plugins/python/helpers/pydev/_pydev_bundle/pydev_umd.py", line 197, in runfile
pydev_imports.execfile(filename, global_vars, local_vars)  # execute the script
  File "/opt/pycharm-professional/plugins/python/helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "/home/hany/phd/repos/playground/misc/pyro/dpgmm.py", line 174, in <module>
elbo = train(5000)
  File "/home/hany/phd/repos/playground/misc/pyro/dpgmm.py", line 106, in train
loss = svi.step(data)
  File "/home/hany/.miniconda3/envs/hbreps/lib/python3.7/site-packages/pyro/infer/svi.py", line 128, in step
loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
  File "/home/hany/.miniconda3/envs/hbreps/lib/python3.7/site-packages/pyro/infer/trace_elbo.py", line 126, in loss_and_grads
for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
  File "/home/hany/.miniconda3/envs/hbreps/lib/python3.7/site-packages/pyro/infer/elbo.py", line 170, in _get_traces
yield self._get_trace(model, guide, args, kwargs)
  File "/home/hany/.miniconda3/envs/hbreps/lib/python3.7/site-packages/pyro/infer/trace_elbo.py", line 53, in _get_trace
"flat", self.max_plate_nesting, model, guide, args, kwargs)
  File "/home/hany/.miniconda3/envs/hbreps/lib/python3.7/site-packages/pyro/infer/enum.py", line 44, in get_importance_trace
guide_trace = poutine.trace(guide, graph_type=graph_type).get_trace(*args, **kwargs)
  File "/home/hany/.miniconda3/envs/hbreps/lib/python3.7/site-packages/pyro/poutine/trace_messenger.py", line 185, in get_trace
self(*args, **kwargs)
  File "/home/hany/.miniconda3/envs/hbreps/lib/python3.7/site-packages/pyro/poutine/trace_messenger.py", line 169, in __call__
raise exc_type(u"{}\n{}".format(exc_value, shapes)).with_traceback(traceback)
  File "/home/hany/.miniconda3/envs/hbreps/lib/python3.7/site-packages/pyro/poutine/trace_messenger.py", line 165, in __call__
ret = self.fn(*args, **kwargs)
  File "/home/hany/phd/repos/playground/misc/pyro/dpgmm.py", line 87, in guide
q_mu = pyro.sample("mu", MultivariateNormal(tau, scale_tril=q_sigma_chol))
  File "/home/hany/.miniconda3/envs/hbreps/lib/python3.7/site-packages/torch/distributions/multivariate_normal.py", line 144, in __init__
super(MultivariateNormal, self).__init__(batch_shape, event_shape, validate_args=validate_args)
  File "/home/hany/.miniconda3/envs/hbreps/lib/python3.7/site-packages/torch/distributions/distribution.py", line 36, in __init__
raise ValueError("The parameter {} has invalid values".format(param))
ValueError: The parameter scale_tril has invalid values
   Trace Shapes:          
    Param Sites:          
           gamma     4    
            zeta    10    
           delta    10    
             tau  5  2    
              pi 75  5    
auto_corr_chol_0  2  2    
auto_corr_chol_1  2  2    
auto_corr_chol_2  2  2    
auto_corr_chol_3  2  2    
auto_corr_chol_4  2  2    
   Sample Sites:          
 beta_plate dist     |    
           value  4  |    
       beta dist  4  |    
           value  4  |    
 prec_plate dist     |    
           value 10  |    
       prec dist 10  |    
           value 10  |    
corr_chol_plate dist     |    
           value  5  |    
corr_chol_0 dist     | 2 2
           value     | 2 2
corr_chol_1 dist     | 2 2
           value     | 2 2
corr_chol_2 dist     | 2 2
           value     | 2 2
corr_chol_3 dist     | 2 2
           value     | 2 2
corr_chol_4 dist     | 2 2
           value     | 2 2
   mu_plate dist     |    
           value  5  |

I believe you’re missing a constraint in your param statement:

from pyro.distributions.constraints import corr_cholesky_constraint
# Keep the point estimate on the set of valid correlation Cholesky factors.
pyro.param("auto_corr_chol_{}".format(t), torch.eye(D, device=device), constraint=corr_cholesky_constraint)

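Also note that you should be able to vectorize your corr_chol_plate, i.e. sample all T correlation factors at a single site inside `with pyro.plate(...)` instead of looping over t.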

1 Like