Gaussian Mixture Model with partially observed data

Hi, I’m trying to extend this GMM implementation (https://github.com/mcdickenson/em-gaussian/blob/master/em-gaussian-pyro.py) in three ways:

  1. Non-diagonal covariance matrices
  2. Partially observed labels (i.e. for some x, observe the label for which mixture it came from)
  3. Upweight the observed labels relative to the unobserved ones.

I was able to implement (1) without issue. However, I am persistently running into initialization issues with my guide when implementing (2), despite the shapes of each tensor/sample being what I’d expect. I’ve provided some reproducible code below:

import matplotlib.pyplot as plt
import numpy as np
import pyro
import pyro.distributions as dist
import torch

from matplotlib.patches import Ellipse
from pyro import poutine
from pyro.contrib.autoguide import AutoDelta
from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate
from torch.distributions import constraints


@config_enumerate(default='parallel')
@poutine.broadcast
def model(data, labels):
    # Global variables.
    weights = pyro.param('weights', torch.FloatTensor([0.5]), constraint=constraints.unit_interval)
    # scales = pyro.param('scales', torch.tensor([[[1., 0.], [0., 2.]], [[3., 0.], [0., 4.]]]), constraint=constraints.positive)
    scale_trils = pyro.param('scale_trils', 
                             torch.stack([torch.eye(2) for _ in range(K)]), 
                             constraint=constraints.lower_cholesky)
    locs = pyro.param('locs', torch.tensor([[1., 2.], [3., 4.]]))


    with pyro.iarange('data', data.size(0)):
        # Local variables.
        for i in pyro.irange('data_loop', data.size(0)):
            if labels[i] == -1:  # Unlabeled data
                assignment = pyro.sample(f'assignment_{i}', dist.Bernoulli(torch.ones(1) * weights)).to(torch.int64)
            else:  # Labeled data
                assignment = labels[i].long()

            pyro.sample(f'obs_{i}', dist.MultivariateNormal(locs[assignment], scale_tril=scale_trils[assignment]), obs=data[i])


@config_enumerate(default="parallel")
@poutine.broadcast
def full_guide(data, labels):
    with pyro.iarange('data', len(data)):
        unlabeled_mask = (labels == -1)

        # for obs mask to work, need to make sure labels are 0/1 even at unobserved sites
        labels[labels == -1] = 0

        assignment_probs = pyro.param('assignment_probs', torch.ones(K) / K,
                                      constraint=constraints.simplex)
        # Expanding assignment_probs to match the batch shape of the data plate
        assignment_probs = assignment_probs.expand(len(data), -1)

        pyro.sample('assignment', dist.Bernoulli(assignment_probs),
                    infer={"enumerate": "sequential"}, obs=labels, obs_mask=unlabeled_mask)





def initialize(data, labels):
    pyro.clear_param_store()

    optim = pyro.optim.Adam({'lr': 0.1, 'betas': [0.8, 0.99]})
    elbo = TraceEnum_ELBO(max_iarange_nesting=1)
    svi = SVI(model, full_guide, optim, loss=elbo)

    # Initialize weights to uniform.
    pyro.param('auto_weights', 0.5 * torch.ones(K), constraint=constraints.simplex)

    # Assume half of the data variance is due to intra-component noise.
    var = (data.var() / 2).sqrt()
    pyro.param('auto_scale_trils', 
               torch.stack([torch.eye(2) for _ in range(K)]), 
               constraint=constraints.lower_cholesky)
    # pyro.param('auto_scale', torch.tensor([var]*4), constraint=constraints.positive)

    # Initialize means from a subsample of data.
    pyro.param('auto_locs', data[torch.multinomial(torch.ones(len(data)) / len(data), K)])

    loss = svi.loss(model, full_guide, data, labels)

    return loss, svi


def get_samples(labeled_fraction=0.2):
    num_samples = 100

    # 2 clusters
    # note that both covariance matrices are diagonal
    mu1 = torch.tensor([0., 5.])
    sig1 = torch.tensor([[2., 1.], [1., 2.]])

    mu2 = torch.tensor([5., 0.])
    sig2 = torch.tensor([[4., 3.], [3., 4.]])

    # generate samples
    dist1 = dist.MultivariateNormal(mu1, sig1)
    samples1 = [pyro.sample('samples1', dist1) for _ in range(num_samples)]

    dist2 = dist.MultivariateNormal(mu2, sig2)
    samples2 = [pyro.sample('samples2', dist2) for _ in range(num_samples)]

    data = torch.cat((torch.stack(samples1), torch.stack(samples2)))
    labels = torch.cat((torch.zeros(num_samples, dtype=torch.long), 
                        torch.ones(num_samples, dtype=torch.long)))
    
    # Randomly mask a fraction of labels
    num_labeled = int(len(data) * labeled_fraction)
    labeled_indices = torch.randperm(len(data))[:num_labeled]
    labels_masked = torch.full(labels.shape, -1, dtype=torch.long)  # Ensure Long type
    labels_masked[labeled_indices] = labels[labeled_indices]

    return data, labels_masked

    # return data


def plot(data, mus=None, scale_trils=None, sigmas=None, colors='black', figname='fig.png'):
    # Create figure
    fig = plt.figure()

    # Plot data
    x = data[:, 0]
    y = data[:, 1]
    plt.scatter(x, y, 24, c=colors)

    # Plot cluster centers
    if mus is not None:
        x = [float(m[0]) for m in mus]
        y = [float(m[1]) for m in mus]
        plt.scatter(x, y, 99, c='red')

    # Plot ellipses for each cluster
    if sigmas is not None:
        if scale_trils is not None:
              # Reconstruct the full covariance matrix
            for sig_ix in range(K):
                cov = scale_trils[sig_ix] @ scale_trils[sig_ix].T
                ax = fig.gca()
                # cov = np.array(sigmas[sig_ix])
                lam, v = np.linalg.eig(cov)
                lam = np.sqrt(lam)
                ell = Ellipse(xy=(x[sig_ix], y[sig_ix]),
                            width=lam[0]*4, height=lam[1]*4,
                            angle=np.rad2deg(np.arccos(v[0, 0])),
                            color='blue')
                ell.set_facecolor('none')
                ax.add_artist(ell)

    # Save figure
    fig.savefig(figname)


if __name__ == "__main__":
    pyro.enable_validation(True)
    pyro.set_rng_seed(42)

    # Create our model with a fixed number of components
    K = 2

    data, labels = get_samples()

    global_guide = AutoDelta(poutine.block(model, expose=['weights', 'locs', 'scales']))
    global_guide = config_enumerate(global_guide, 'parallel')
    print("initializing...")
    _, svi = initialize(data, labels)
    print("initialized")

    true_colors = [0] * 100 + [1] * 100
    plot(data, colors=true_colors, figname='pyro_init.png')

    for i in range(501):
        print(i)
        # scale_trils = pyro.param('scale_trils')
        svi.step(data, labels)

        if i % 50 == 0:
            locs = pyro.param('locs')
            scale_trils = pyro.param('scale_trils')
            weights = pyro.param('weights')
            assignment_probs = pyro.param('assignment_probs')

            print("locs: {}".format(locs))
            print("scales: {}".format(scale_trils))
            print('weights = {}'.format(weights))
            print('assignments: {}'.format(assignment_probs))

            # todo plot data and estimates
            assignments = np.uint8(np.round(assignment_probs.data))
            plot(data, locs.data, scale_trils.data, assignments, figname='pyro_iteration{}.png'.format(i))

You can compare this code to the one in the GitHub link as well — the only changes are the non-diagonal covariance matrices and the partially observed labels.