I am trying to implement a GMM class that I can substitute for sklearn’s implementation. This is my code:
class GMM:
def __init__(self, n_components, max_iter=1000, tol=1e-3, learning_rate=0.05, verbose=1):
self.n_components = n_components
self.max_iter = max_iter
self.tol = tol
self.learning_rate = learning_rate
self.verbose = verbose
self.means = None
self.scales = None
self.weights = None
self.probabilities_ = None
def model(self, data):
n_datapoints, dim = data.shape
with pyro.plate('mixture_components', self.n_components):
means = pyro.sample('means', dist.Normal(torch.zeros(dim), torch.ones(dim)).to_event(1))
scales = pyro.sample('scales', dist.LogNormal(torch.zeros(dim), torch.ones(dim)).to_event(1))
weights = pyro.sample('weights', dist.Dirichlet(torch.ones(self.n_components)))
with pyro.plate('data', n_datapoints):
assignment = pyro.sample('assignment', dist.Categorical(weights))
pyro.sample('obs', dist.Normal(means[assignment], scales[assignment]).to_event(1), obs=data)
def guide(self, data):
n_datapoints, dim = data.shape
with pyro.plate('mixture_components', self.n_components):
means_loc = pyro.param('means_loc', torch.zeros((self.n_components, dim)))
means_scale = pyro.param('means_scale', torch.ones((self.n_components, dim)), constraint=constraints.positive)
scales_loc = pyro.param('scales_loc', torch.zeros((self.n_components, dim)))
scales_scale = pyro.param('scales_scale', torch.ones((self.n_components, dim)), constraint=constraints.positive)
pyro.sample('means', dist.Normal(means_loc, means_scale).to_event(1))
pyro.sample('scales', dist.LogNormal(scales_loc, scales_scale).to_event(1))
weights_dir = pyro.param('weights_dir', torch.ones(self.n_components), constraint=constraints.positive)
pyro.sample('weights', dist.Dirichlet(weights_dir))
with pyro.plate('data', n_datapoints):
pyro.sample('assignment', dist.Categorical(weights_dir))
def fit(self, data):
svi = SVI(self.model, self.guide, Adam({'lr': self.learning_rate}), Trace_ELBO())
previous_loss = float('inf')
for step in range(self.max_iter):
current_loss = svi.step(data)
loss_difference = previous_loss - current_loss
if self.verbose:
logging.info(f"Step {step}: Loss = {current_loss}, ΔLoss = {loss_difference}")
# Check for convergence
if abs(loss_difference) < self.tol:
print(f"Converged at step {step}")
break
previous_loss = current_loss
self.means = pyro.param('means_loc').detach()
self.scales = pyro.param('scales_loc').detach().exp() # from log scale
self.weights = pyro.param('weights_dir').detach()
def predict(self, data):
mixture_dist = dist.Categorical(self.weights)
component_dists = dist.Normal(self.means, self.scales)
with torch.no_grad():
assignments = [torch.argmax(mixture_dist.log_prob(component_dists.log_prob(d)).sum(-1), -1) for d in data]
return torch.stack(assignments)
def calculate_probabilities(self, data):
mixture_dist = dist.Categorical(self.weights)
component_dists = dist.Normal(self.means, self.scales)
with torch.no_grad():
probs = [mixture_dist.log_prob(component_dists.log_prob(d)).exp().sum(-1) for d in data]
self.probabilities_ = torch.stack(probs)
When I run it, however, the estimates aren’t even close to sensible and the loss fluctuates like crazy, irrespective of the learning rate. Here is some code to test things:
def test_data():
import torch
# Define true parameters for a 2-component GMM
true_means = torch.tensor([[2.0, 2.0], [-2.0, -2.0]])
true_scales = torch.tensor([[0.5, 0.5], [0.5, 0.5]])
true_weights = torch.tensor([0.5, 0.5])
# Generate synthetic data
n_data = 1000
assignment = torch.multinomial(true_weights, n_data, replacement=True)
data = torch.stack([torch.normal(true_means[i], true_scales[i]) for i in assignment])
return data, true_means, true_scales, true_weights, assignment
def test_GMM(learning_rate=0.05, max_iter=1_000):
import matplotlib.pyplot as plt
from sklearn.mixture import GaussianMixture
import numpy as np
data, true_means, true_scales, true_weights, assignment = test_data()
# Fit the GMM model to the synthetic data
gmm = GMM(n_components=2, max_iter=max_iter, learning_rate=learning_rate)
gmm.fit(data)
# Fit Scikit-Learn's GMM model to the same data
sklearn_gmm = GaussianMixture(n_components=2, max_iter=500, tol=1e-3)
sklearn_gmm.fit(data.numpy())
# Compare the inferred parameters with the true parameters
print("True Means:", true_means)
print("Pyro Inferred Means:", gmm.means)
print("Sklearn Inferred Means:", torch.tensor(sklearn_gmm.means_))
print("True Scales:", true_scales)
print("Pyro Inferred Scales:", gmm.scales)
print("Sklearn Inferred Scales:", torch.sqrt(torch.tensor(sklearn_gmm.covariances_).diag_embed().sum(-1)))
print("True Weights:", true_weights)
print("Pyro Inferred Weights:", gmm.weights)
print("Sklearn Inferred Weights:", torch.tensor(sklearn_gmm.weights_))
# Visualize the results
plt.scatter(data[:, 0], data[:, 1], c=assignment, cmap='viridis', alpha=0.5)
plt.scatter(true_means[:, 0], true_means[:, 1], marker='x', color='red', label='True Means')
plt.scatter(gmm.means[:, 0], gmm.means[:, 1], marker='o', color='blue', label='Pyro Inferred Means')
plt.scatter(sklearn_gmm.means_[:, 0], sklearn_gmm.means_[:, 1], marker='^', color='green', label='Sklearn Inferred Means')
plt.legend()
plt.show()
It seems clear that there is something wrong with my code and that it’s not just an initialization issue, but I’m completely new to pyro and just don’t see what’s wrong, even after going through the GMM tutorial.