Question regarding Gaussian mixture model (GMM) using EM algorithm

Hello everyone! I have read the GMM tutorial on the website, and I wanted to create a naive version of GMM in which the mixture weights, means, and covariance matrices are treated as learnable parameters rather than latent variables; in other words, no prior distributions are placed on them. The model assumptions and the E-step are shown in the attached picture.
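(For reference, the E-step my guide below is meant to implement is the usual GMM responsibility update; writing it out here in case the picture does not render:)

$$\gamma_{ik} \;=\; q(z_i = k) \;=\; \frac{\pi_k\,\mathcal{N}(x_i \mid \mu_k, \Sigma_k)}{\sum_{j=1}^{K} \pi_j\,\mathcal{N}(x_i \mid \mu_j, \Sigma_j)}$$

This is exactly the normalized log_probs computation in the guide code.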

I think the main problem is that I don't know how to use the same set of parameters in both the model and the guide. In Pyro, it seems that the parameters (theta) of the generative process are defined in the model, while the parameters (phi) of the inference process are defined in the guide. In the EM algorithm we essentially use the same set of parameters for the E-step and the M-step, and I don't know how to reflect this in my code.

Here’s my model and guide:

import torch
import pyro
import pyro.distributions as dist

def model(data, K=3):
    num_obs, num_dim = data.shape

    # fixed (non-learned) means and identity scales in this first version
    means = torch.zeros(K, num_dim)
    scale_tril = torch.eye(num_dim).repeat(K, 1, 1)

    with pyro.plate("data", size=num_obs):
        assignment = pyro.sample("assignment", dist.Categorical(torch.ones(K) / K))
        pyro.sample("obs", dist.MultivariateNormal(means[assignment], scale_tril=scale_tril[assignment]), obs=data)

def guide(data, K=3):
    num_obs, num_dim = data.shape

    # define learnable parameters: mixture weights, means, and covariance factors
    weights = pyro.param("weights", torch.ones(K) / K, constraint=dist.constraints.simplex)
    means = pyro.param("means", torch.randn(K, num_dim))
    chol_factor = pyro.param("chol_factor", torch.eye(num_dim).repeat(K, 1, 1))
    scale_diag = pyro.param("scale_diag", torch.ones(K, num_dim), constraint=dist.constraints.positive)
    scale_tril = chol_factor @ torch.diag_embed(scale_diag)

    with pyro.plate("data", size=num_obs):
        # responsibilities: log(pi_k) + log N(x_i | mu_k, Sigma_k), normalized over k
        log_probs = torch.stack(
            [dist.MultivariateNormal(means[k], scale_tril=scale_tril[k]).log_prob(data) for k in range(K)],
            dim=-1,
        ) + torch.log(weights)
        assignment_probs = torch.exp(log_probs - torch.logsumexp(log_probs, dim=-1, keepdim=True))
        assignment = pyro.sample("assignment", dist.Categorical(probs=assignment_probs))

Here is how I generate my data:

import numpy as np

weights = np.array([2/8, 5/8, 1/8])
means = np.array([[1, 2],
                  [5, 6],
                  [3, 1]])
# covariance matrices must be symmetric positive definite
covariances = np.array([[[1.0, 0.2], [0.2, 1.5]],
                        [[3.0, -0.3], [-0.3, 9.0]],
                        [[1.2, 0.4], [0.4, 1.3]]])
def sample_from_mixture(weights, means, covariances, num_samples=100):
    num_components = len(weights)
    components = np.random.choice(num_components, size=num_samples, p=weights)
    samples = np.array([
        np.random.multivariate_normal(means[comp], covariances[comp])
        for comp in components
    ])
    return samples

num_samples = 1000
samples = sample_from_mixture(weights, means, covariances, num_samples)
samples = torch.tensor(samples, dtype=torch.float) 

Here’s my training setup:

from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

pyro.clear_param_store()
optim = Adam({"lr": 0.001, "betas": [0.8, 0.99]})
elbo = Trace_ELBO()
svi = SVI(model, guide, optim, elbo)

num_steps = 20000
for step in range(num_steps):
    loss = svi.step(samples, K=3)

Training results: training got stuck in a local minimum, and only one mode was learned. I think one reason is that my model hard-codes the distributions for z ("assignment") and x ("obs") (uniform weights, zero means, identity covariances) instead of using the learnable parameters.
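(One thing I have not tried in this version is initializing the means from randomly chosen data points instead of torch.randn, which is a common way to avoid collapsing onto a single mode; a rough sketch reusing my parameter name "means":)

def init_means_from_data(data, K=3):
    # pick K random data points as initial means;
    # call this after pyro.clear_param_store() and before the guide first registers "means"
    idx = torch.multinomial(torch.ones(data.shape[0]) / data.shape[0], K, replacement=False)
    return pyro.param("means", data[idx].clone())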

Hi,

You can use the same pyro.param both in the model and in the guide. For example, in your case you can use pyro.param("weights") in the model as well:

def model(data, K=3):
    num_obs, num_dim = data.shape

    with pyro.plate("data", size=num_obs):
        assignment = pyro.sample("assignment", dist.Categorical(pyro.param("weights")))

Note that if you use a parameter in both the model and the guide, you only need to provide its initial value (and constraint) in the guide; in the model you can simply look it up by name.
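Here is a minimal sketch of the pattern with made-up names (toy_model / toy_guide): the guide owns the initialization and constraint, the model just looks the parameter up by name. This works under SVI because the guide is run before the model is replayed.

def toy_model(data):
    p = pyro.param("p")  # looked up by name; already registered by the guide
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Bernoulli(p), obs=data)

def toy_guide(data):
    # initial value and constraint are given here only
    pyro.param("p", torch.tensor(0.5), constraint=dist.constraints.unit_interval)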


Thank you for the suggestion; the code now works with no problems. I had to define a custom loss so that the parameters are penalized only once: only the model trace is scored, which is exactly the EM M-step objective. My updated code:

def model(data, K=3):
    num_obs, num_dim = data.shape

    # retrieve the shared parameters (initialized in the guide)
    weights = pyro.param("weights")
    means = pyro.param("means")
    chol_factor = pyro.param("chol_factor")
    scale_diag = pyro.param("scale_diag")
    scale_tril = chol_factor @ torch.diag_embed(scale_diag)

    with pyro.plate("data", size=num_obs):
        assignment = pyro.sample("assignment", dist.Categorical(weights))
        pyro.sample("obs", dist.MultivariateNormal(means[assignment], scale_tril=scale_tril[assignment]), obs=data)

def guide(data, K=3):
    num_obs, num_dim = data.shape

    # define learnable parameters: mixture weights, means, and covariance factors
    weights = pyro.param("weights", torch.ones(K) / K, constraint=dist.constraints.simplex)
    means = pyro.param("means", torch.randn(K, num_dim))
    chol_factor = pyro.param("chol_factor", torch.eye(num_dim).repeat(K, 1, 1))
    scale_diag = pyro.param("scale_diag", torch.ones(K, num_dim), constraint=dist.constraints.positive)
    scale_tril = chol_factor @ torch.diag_embed(scale_diag)

    with pyro.plate("data", size=num_obs):
        # responsibilities: log(pi_k) + log N(x_i | mu_k, Sigma_k), normalized over k
        log_probs = torch.stack(
            [dist.MultivariateNormal(means[k], scale_tril=scale_tril[k]).log_prob(data) for k in range(K)],
            dim=-1,
        ) + torch.log(weights)
        assignment_probs = torch.exp(log_probs - torch.logsumexp(log_probs, dim=-1, keepdim=True))
        assignment = pyro.sample("assignment", dist.Categorical(probs=assignment_probs))

def EM_m_step(model, guide, *args, **kwargs):
    # E-step: run the guide to sample the assignments (responsibilities)
    guide_trace = pyro.poutine.trace(guide).get_trace(*args, **kwargs)
    # M-step objective: score the model against the guide's assignments
    model_trace = pyro.poutine.trace(pyro.poutine.replay(model, trace=guide_trace)).get_trace(*args, **kwargs)

    elbo = 0.0
    for site in model_trace.nodes.values():
        if site["type"] == "sample":
            elbo = elbo + site["fn"].log_prob(site["value"]).sum()

    # negative log-joint of the model under the guide's assignments;
    # the guide's entropy term is deliberately left out
    return -elbo
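(In other words, this loss is a single-sample Monte Carlo estimate of the negative M-step objective, with the assignments sampled from the guide and treated as fixed:)

$$\mathcal{L}(\theta) \;=\; -\,\mathbb{E}_{z \sim q(z \mid x)}\big[\log p_\theta(x, z)\big] \;\approx\; -\log p_\theta(x, \hat{z}), \qquad \hat{z} \sim q(z \mid x)$$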

Select the best initialization:

optim = Adam({"lr": 0.001, "betas": [0.8, 0.99]})

def initialize(seed):
    global svi
    pyro.set_rng_seed(seed)
    pyro.clear_param_store()
    svi = SVI(model, guide, optim, EM_m_step)
    return svi.loss(model, guide, samples, K=3)

# Choose the best among 100 random initializations.
loss, seed = min((initialize(seed), seed) for seed in range(100))  # tuples compare by their first element, i.e. the loss
initialize(seed)
print(f"seed = {seed}, initial_loss = {loss}")

Training:

num_steps = 10000
losses = []
for step in range(num_steps):
    loss = svi.step(samples, K=3)
    losses.append(loss)
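(Since the losses are collected anyway, plotting the loss curve is a quick sanity check that training has converged; a minimal sketch:)

import matplotlib.pyplot as plt

plt.figure(figsize=(8, 3))
plt.plot(losses)
plt.xlabel("SVI step")
plt.ylabel("loss (negative M-step objective)")
plt.show()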

Visualization:

from scipy.stats import multivariate_normal

learned_weights = pyro.param("weights").detach()
learned_means = pyro.param("means").detach()
chol_factor = pyro.param("chol_factor").detach()
learned_scale_diag = pyro.param("scale_diag").detach()
learned_scale_tril = chol_factor @ torch.diag_embed(learned_scale_diag)

# Define a grid for the contour plot
x, y = np.linspace(-5, 10, 100), np.linspace(-5, 15, 100)
X, Y = np.meshgrid(x, y)
pos = np.dstack((X, Y))  # shape (100, 100, 2): grid of (x, y) points

# Evaluate the learned mixture density on the grid
pdf = np.zeros(X.shape)
for i in range(learned_weights.shape[0]):
    mean = learned_means[i].numpy()
    cov = (learned_scale_tril[i] @ learned_scale_tril[i].t()).numpy()  # recover the covariance from its factor
    rv = multivariate_normal(mean=mean, cov=cov)
    pdf_component = rv.pdf(pos)  # density of component i on the grid
    pdf += learned_weights[i].numpy() * pdf_component  # scale by the mixture weight

# Create the contour plot
plt.figure(figsize=(8, 6))
plt.contour(X, Y, pdf, levels=50, cmap='viridis')
plt.scatter(samples[:, 0], samples[:, 1], alpha=0.6, label='Samples')
plt.scatter(learned_means[:, 0], learned_means[:, 1], c='r', s=100, marker='x', label='Learned means')
plt.title('Contour Plot of the Learned Mixture Distribution')
plt.xlabel('X1')
plt.ylabel('X2')
plt.legend()
plt.grid(True)
plt.show()
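(As a last check, one can also print the learned parameters next to the ones used to generate the data; the component ordering is arbitrary, so they have to be matched up by eye. A quick sketch:)

print("learned weights:", learned_weights.numpy().round(3), "  true:", weights)
print("learned means:\n", learned_means.numpy().round(2), "\n true:\n", means)
for i in range(learned_weights.shape[0]):
    learned_cov = (learned_scale_tril[i] @ learned_scale_tril[i].t()).numpy()
    print(f"learned covariance {i}:\n", learned_cov.round(2))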

Thank you Pyro development team, you guys are the BEST! Pyro is the most powerful and flexible tool I have ever used. Thank you!
