GMM estimates completely off

I am trying to implement a GMM class in Pyro that I can substitute for sklearn's implementation. This is my code:

import logging

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam
from torch.distributions import constraints


class GMM:
    def __init__(self, n_components, max_iter=1000, tol=1e-3, learning_rate=0.05, verbose=1):
        self.n_components = n_components
        self.max_iter = max_iter
        self.tol = tol
        self.learning_rate = learning_rate
        self.verbose = verbose
        self.means = None
        self.scales = None
        self.weights = None
        self.probabilities_ = None


    def model(self, data):
        n_datapoints, dim = data.shape

        with pyro.plate('mixture_components', self.n_components):
            means = pyro.sample('means', dist.Normal(torch.zeros(dim), torch.ones(dim)).to_event(1))
            scales = pyro.sample('scales', dist.LogNormal(torch.zeros(dim), torch.ones(dim)).to_event(1))

        weights = pyro.sample('weights', dist.Dirichlet(torch.ones(self.n_components)))

        with pyro.plate('data', n_datapoints):
            assignment = pyro.sample('assignment', dist.Categorical(weights))

            pyro.sample('obs', dist.Normal(means[assignment], scales[assignment]).to_event(1), obs=data)


    def guide(self, data):
        n_datapoints, dim = data.shape

        with pyro.plate('mixture_components', self.n_components):
            means_loc = pyro.param('means_loc', torch.zeros((self.n_components, dim)))
            means_scale = pyro.param('means_scale', torch.ones((self.n_components, dim)), constraint=constraints.positive)
            scales_loc = pyro.param('scales_loc', torch.zeros((self.n_components, dim)))
            scales_scale = pyro.param('scales_scale', torch.ones((self.n_components, dim)), constraint=constraints.positive)

            pyro.sample('means', dist.Normal(means_loc, means_scale).to_event(1))
            pyro.sample('scales', dist.LogNormal(scales_loc, scales_scale).to_event(1))

        weights_dir = pyro.param('weights_dir', torch.ones(self.n_components), constraint=constraints.positive)
        pyro.sample('weights', dist.Dirichlet(weights_dir))

        with pyro.plate('data', n_datapoints):
            pyro.sample('assignment', dist.Categorical(weights_dir))


    def fit(self, data):
        svi = SVI(self.model, self.guide, Adam({'lr': self.learning_rate}), Trace_ELBO())
        previous_loss = float('inf')
        for step in range(self.max_iter):
            current_loss = svi.step(data)
            loss_difference = previous_loss - current_loss

            if self.verbose:
                logging.info(f"Step {step}: Loss = {current_loss}, ΔLoss = {loss_difference}")

            # Check for convergence
            if abs(loss_difference) < self.tol:
                print(f"Converged at step {step}")
                break

            previous_loss = current_loss

        self.means = pyro.param('means_loc').detach()
        self.scales = pyro.param('scales_loc').detach().exp()  # scales_loc parameterizes a LogNormal, so exp() gives its median
        weights_dir = pyro.param('weights_dir').detach()
        self.weights = weights_dir / weights_dir.sum()  # Dirichlet mean as a point estimate of the mixture weights


    def predict(self, data):
        mixture_dist = dist.Categorical(self.weights)
        component_dists = dist.Normal(self.means, self.scales)

        with torch.no_grad():
            # joint log-probability log p(z=k) + sum_d log p(x_d | z=k) for every point and component
            log_likelihoods = component_dists.log_prob(data.unsqueeze(1)).sum(-1)  # shape (n_datapoints, n_components)
            return torch.argmax(mixture_dist.logits + log_likelihoods, dim=-1)


    def calculate_probabilities(self, data):
        mixture_dist = dist.Categorical(self.weights)
        component_dists = dist.Normal(self.means, self.scales)

        with torch.no_grad():
            # marginal likelihood of each data point under the mixture: sum_k p(z=k) p(x | z=k)
            log_likelihoods = component_dists.log_prob(data.unsqueeze(1)).sum(-1)
            self.probabilities_ = torch.logsumexp(mixture_dist.logits + log_likelihoods, dim=-1).exp()

When I run it, however, the estimates aren’t even close to sensible and the loss fluctuates like crazy, irrespective of the learning rate. Here is some code to test things:

def test_data():
    import torch

    # Define true parameters for a 2-component GMM
    true_means = torch.tensor([[2.0, 2.0], [-2.0, -2.0]])
    true_scales = torch.tensor([[0.5, 0.5], [0.5, 0.5]])
    true_weights = torch.tensor([0.5, 0.5])

    # Generate synthetic data
    n_data = 1000
    assignment = torch.multinomial(true_weights, n_data, replacement=True)
    data = torch.stack([torch.normal(true_means[i], true_scales[i]) for i in assignment])

    return data, true_means, true_scales, true_weights, assignment



def test_GMM(learning_rate=0.05, max_iter=1_000):
    import matplotlib.pyplot as plt
    from sklearn.mixture import GaussianMixture
    import numpy as np

    data, true_means, true_scales, true_weights, assignment = test_data()
    # Fit the GMM model to the synthetic data
    gmm = GMM(n_components=2, max_iter=max_iter, learning_rate=learning_rate)
    gmm.fit(data)

    # Fit Scikit-Learn's GMM model to the same data
    sklearn_gmm = GaussianMixture(n_components=2, max_iter=500, tol=1e-3)
    sklearn_gmm.fit(data.numpy())

    # Compare the inferred parameters with the true parameters
    print("True Means:", true_means)
    print("Pyro Inferred Means:", gmm.means)
    print("Sklearn Inferred Means:", torch.tensor(sklearn_gmm.means_))

    print("True Scales:", true_scales)
    print("Pyro Inferred Scales:", gmm.scales)
    print("Sklearn Inferred Scales:", torch.sqrt(torch.tensor(sklearn_gmm.covariances_).diag_embed().sum(-1)))

    print("True Weights:", true_weights)
    print("Pyro Inferred Weights:", gmm.weights)
    print("Sklearn Inferred Weights:", torch.tensor(sklearn_gmm.weights_))

    # Visualize the results
    plt.scatter(data[:, 0], data[:, 1], c=assignment, cmap='viridis', alpha=0.5)
    plt.scatter(true_means[:, 0], true_means[:, 1], marker='x', color='red', label='True Means')
    plt.scatter(gmm.means[:, 0], gmm.means[:, 1], marker='o', color='blue', label='Pyro Inferred Means')
    plt.scatter(sklearn_gmm.means_[:, 0], sklearn_gmm.means_[:, 1], marker='^', color='green', label='Sklearn Inferred Means')
    plt.legend()
    plt.show()

It seems clear that there is something wrong with my code and that it's not just an initialization issue, but I'm completely new to Pyro and just don't see what's wrong, even after going through the GMM tutorial.

Have you tried using TraceEnum_ELBO, as shown in the GMM tutorial? It reduces the noise in the gradients and helps convergence.
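
For example, something along these lines could replace the loss setup in fit() (a rough sketch, assuming an instance gmm of your class; note that with enumeration the 'assignment' site must not also be sampled in the guide, hence the poutine.block):

from pyro import poutine
from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate
from pyro.optim import Adam

enum_model = config_enumerate(gmm.model, "parallel")        # enumerate the discrete 'assignment' site in parallel
enum_guide = poutine.block(gmm.guide, hide=['assignment'])  # enumerated sites must not appear in the guide
svi = SVI(enum_model, enum_guide, Adam({'lr': gmm.learning_rate}), TraceEnum_ELBO(max_plate_nesting=1))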

More critically, in the guide you have to use per-datapoint assignment_probs instead of weights_dir for the Categorical probs, i.e. pyro.sample("assignment", dist.Categorical(assignment_probs)), as in the tutorial.
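
For instance (a sketch; assignment_probs is a new learnable parameter of shape (n_datapoints, n_components), named as in the tutorial):

# inside the guide, replacing the current 'assignment' site
assignment_probs = pyro.param('assignment_probs',
                              torch.ones(n_datapoints, self.n_components) / self.n_components,
                              constraint=constraints.simplex)
with pyro.plate('data', n_datapoints):
    pyro.sample('assignment', dist.Categorical(assignment_probs))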

Thanks for the reply. I don't quite understand this, though. Why wouldn't sampling global probabilities from a Dirichlet distribution work? Modeling the individual assignment probabilities of each datapoint directly seems odd. I'm also not sure how this fits together with model() then, because the two now have parameters with different shapes ("weights" vs. "assignment_probs").

I have tried using TraceEnum_ELBO as well, but it didn't help. I don't think it's a stability issue either; this is what the first couple of losses look like:

07.02.2024 14:17:20 Step 1: Loss = 22012.927734375, ΔLoss = 21013.380859375
07.02.2024 14:17:20 Step 2: Loss = 23817.86328125, ΔLoss = -1804.935546875
07.02.2024 14:17:20 Step 3: Loss = 14505.9384765625, ΔLoss = 9311.9248046875
07.02.2024 14:17:20 Step 4: Loss = 151988.09375, ΔLoss = -137482.1552734375
07.02.2024 14:17:20 Step 5: Loss = 6751.86279296875, ΔLoss = 145236.23095703125
07.02.2024 14:17:20 Step 6: Loss = 20873.6796875, ΔLoss = -14121.81689453125
07.02.2024 14:17:20 Step 7: Loss = 11743.2783203125, ΔLoss = 9130.4013671875
07.02.2024 14:17:20 Step 8: Loss = 6044.62158203125, ΔLoss = 5698.65673828125
07.02.2024 14:17:20 Step 9: Loss = 28566.6953125, ΔLoss = -22522.07373046875
07.02.2024 14:17:20 Step 10: Loss = 10670.447265625, ΔLoss = 17896.248046875
07.02.2024 14:17:20 Step 11: Loss = 11903.3818359375, ΔLoss = -1232.9345703125
07.02.2024 14:17:20 Step 12: Loss = 374360.65625, ΔLoss = -362457.2744140625
07.02.2024 14:17:20 Step 13: Loss = 34003.15625, ΔLoss = 340357.5

Why is it odd? If you have 100 data points, then each of them belongs to one of the two clusters. Thus you need 100 assignment probabilities, not just the global weights_dir. If you write down the equations for the probabilities, it might become clearer:

The model:
p(\mu)p(\sigma)p(w)\prod_i p(\mathrm{assignment}_i|w)p(\mathrm{obs}_i|\mu,\sigma,\mathrm{assignment}_i)

The guide (which approximates the posterior distribution):
q(\mu)q(\sigma)q(w)\prod_i q(\mathrm{assignment}_i)

Here q(\mathrm{assignment}_i) is the (Categorical) probability of data point i belonging to cluster 1 or 2.
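
Concretely, a guide matching this factorization could look like the sketch below for your class (the model and its 'weights' site stay exactly as they are; only the guide's 'assignment' site gets its own per-datapoint probabilities, and the assignment_probs name and initialization are illustrative):

    def guide(self, data):
        n_datapoints, dim = data.shape

        # q(mu) and q(sigma): global, one factor per mixture component
        with pyro.plate('mixture_components', self.n_components):
            means_loc = pyro.param('means_loc', torch.zeros((self.n_components, dim)))
            means_scale = pyro.param('means_scale', torch.ones((self.n_components, dim)), constraint=constraints.positive)
            scales_loc = pyro.param('scales_loc', torch.zeros((self.n_components, dim)))
            scales_scale = pyro.param('scales_scale', torch.ones((self.n_components, dim)), constraint=constraints.positive)
            pyro.sample('means', dist.Normal(means_loc, means_scale).to_event(1))
            pyro.sample('scales', dist.LogNormal(scales_loc, scales_scale).to_event(1))

        # q(w): global mixture weights
        weights_dir = pyro.param('weights_dir', torch.ones(self.n_components), constraint=constraints.positive)
        pyro.sample('weights', dist.Dirichlet(weights_dir))

        # prod_i q(assignment_i): one Categorical per data point, probs of shape (n_datapoints, n_components)
        assignment_probs = pyro.param('assignment_probs',
                                      torch.ones(n_datapoints, self.n_components) / self.n_components,
                                      constraint=constraints.simplex)
        with pyro.plate('data', n_datapoints):
            pyro.sample('assignment', dist.Categorical(assignment_probs))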