Hi, I’m trying to extend this GMM implementation (https://github.com/mcdickenson/em-gaussian/blob/master/em-gaussian-pyro.py) in three ways:
- Non-diagonal covariance matrices
- Partially observed labels (i.e. for some x, observe the label for which mixture it came from)
- Upweight the observed labels relative to the unobserved ones.
I was able to implement (1) without issue. However, I am persistently running into initialization issues with my guide when implementing (2), despite the shapes of each tensor/sample being what I’d expect. I’ve provided some reproducible code below:
import matplotlib.pyplot as plt
import numpy as np
import pyro
import pyro.distributions as dist
import torch
from matplotlib.patches import Ellipse
from pyro import poutine
from pyro.contrib.autoguide import AutoDelta
from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate
from torch.distributions import constraints
@config_enumerate(default='parallel')
@poutine.broadcast
def model(data, labels):
# Global variables.
weights = pyro.param('weights', torch.FloatTensor([0.5]), constraint=constraints.unit_interval)
# scales = pyro.param('scales', torch.tensor([[[1., 0.], [0., 2.]], [[3., 0.], [0., 4.]]]), constraint=constraints.positive)
scale_trils = pyro.param('scale_trils',
torch.stack([torch.eye(2) for _ in range(K)]),
constraint=constraints.lower_cholesky)
locs = pyro.param('locs', torch.tensor([[1., 2.], [3., 4.]]))
with pyro.iarange('data', data.size(0)):
# Local variables.
for i in pyro.irange('data_loop', data.size(0)):
if labels[i] == -1: # Unlabeled data
assignment = pyro.sample(f'assignment_{i}', dist.Bernoulli(torch.ones(1) * weights)).to(torch.int64)
else: # Labeled data
assignment = labels[i].long()
pyro.sample(f'obs_{i}', dist.MultivariateNormal(locs[assignment], scale_tril=scale_trils[assignment]), obs=data[i])
@config_enumerate(default="parallel")
@poutine.broadcast
def full_guide(data, labels):
with pyro.iarange('data', len(data)):
unlabeled_mask = (labels == -1)
# for obs mask to work, need to make sure labels are 0/1 even at unobserved sites
labels[labels == -1] = 0
assignment_probs = pyro.param('assignment_probs', torch.ones(K) / K,
constraint=constraints.simplex)
# Expanding assignment_probs to match the batch shape of the data plate
assignment_probs = assignment_probs.expand(len(data), -1)
pyro.sample('assignment', dist.Bernoulli(assignment_probs),
infer={"enumerate": "sequential"}, obs=labels, obs_mask=unlabeled_mask)
def initialize(data, labels):
pyro.clear_param_store()
optim = pyro.optim.Adam({'lr': 0.1, 'betas': [0.8, 0.99]})
elbo = TraceEnum_ELBO(max_iarange_nesting=1)
svi = SVI(model, full_guide, optim, loss=elbo)
# Initialize weights to uniform.
pyro.param('auto_weights', 0.5 * torch.ones(K), constraint=constraints.simplex)
# Assume half of the data variance is due to intra-component noise.
var = (data.var() / 2).sqrt()
pyro.param('auto_scale_trils',
torch.stack([torch.eye(2) for _ in range(K)]),
constraint=constraints.lower_cholesky)
# pyro.param('auto_scale', torch.tensor([var]*4), constraint=constraints.positive)
# Initialize means from a subsample of data.
pyro.param('auto_locs', data[torch.multinomial(torch.ones(len(data)) / len(data), K)])
loss = svi.loss(model, full_guide, data, labels)
return loss, svi
def get_samples(labeled_fraction=0.2):
num_samples = 100
# 2 clusters
# note that both covariance matrices are diagonal
mu1 = torch.tensor([0., 5.])
sig1 = torch.tensor([[2., 1.], [1., 2.]])
mu2 = torch.tensor([5., 0.])
sig2 = torch.tensor([[4., 3.], [3., 4.]])
# generate samples
dist1 = dist.MultivariateNormal(mu1, sig1)
samples1 = [pyro.sample('samples1', dist1) for _ in range(num_samples)]
dist2 = dist.MultivariateNormal(mu2, sig2)
samples2 = [pyro.sample('samples2', dist2) for _ in range(num_samples)]
data = torch.cat((torch.stack(samples1), torch.stack(samples2)))
labels = torch.cat((torch.zeros(num_samples, dtype=torch.long),
torch.ones(num_samples, dtype=torch.long)))
# Randomly mask a fraction of labels
num_labeled = int(len(data) * labeled_fraction)
labeled_indices = torch.randperm(len(data))[:num_labeled]
labels_masked = torch.full(labels.shape, -1, dtype=torch.long) # Ensure Long type
labels_masked[labeled_indices] = labels[labeled_indices]
return data, labels_masked
# return data
def plot(data, mus=None, scale_trils=None, sigmas=None, colors='black', figname='fig.png'):
# Create figure
fig = plt.figure()
# Plot data
x = data[:, 0]
y = data[:, 1]
plt.scatter(x, y, 24, c=colors)
# Plot cluster centers
if mus is not None:
x = [float(m[0]) for m in mus]
y = [float(m[1]) for m in mus]
plt.scatter(x, y, 99, c='red')
# Plot ellipses for each cluster
if sigmas is not None:
if scale_trils is not None:
# Reconstruct the full covariance matrix
for sig_ix in range(K):
cov = scale_trils[sig_ix] @ scale_trils[sig_ix].T
ax = fig.gca()
# cov = np.array(sigmas[sig_ix])
lam, v = np.linalg.eig(cov)
lam = np.sqrt(lam)
ell = Ellipse(xy=(x[sig_ix], y[sig_ix]),
width=lam[0]*4, height=lam[1]*4,
angle=np.rad2deg(np.arccos(v[0, 0])),
color='blue')
ell.set_facecolor('none')
ax.add_artist(ell)
# Save figure
fig.savefig(figname)
if __name__ == "__main__":
pyro.enable_validation(True)
pyro.set_rng_seed(42)
# Create our model with a fixed number of components
K = 2
data, labels = get_samples()
global_guide = AutoDelta(poutine.block(model, expose=['weights', 'locs', 'scales']))
global_guide = config_enumerate(global_guide, 'parallel')
print("initializing...")
_, svi = initialize(data, labels)
print("initialized")
true_colors = [0] * 100 + [1] * 100
plot(data, colors=true_colors, figname='pyro_init.png')
for i in range(501):
print(i)
# scale_trils = pyro.param('scale_trils')
svi.step(data, labels)
if i % 50 == 0:
locs = pyro.param('locs')
scale_trils = pyro.param('scale_trils')
weights = pyro.param('weights')
assignment_probs = pyro.param('assignment_probs')
print("locs: {}".format(locs))
print("scales: {}".format(scale_trils))
print('weights = {}'.format(weights))
print('assignments: {}'.format(assignment_probs))
# todo plot data and estimates
assignments = np.uint8(np.round(assignment_probs.data))
plot(data, locs.data, scale_trils.data, assignments, figname='pyro_iteration{}.png'.format(i))
You can compare this code to the one in the GitHub link as well — the only changes are the non-diagonal covariance matrices and the partially observed labels.