Related: this post.
CrossCat generalizes the Dirichlet Process Mixture Model (DPMM) by fitting individual DPMMs to disjoint feature subsets (views). The assignment of features to views is also determined by a Dirichlet Process. Because the clustering within a view depends on the features assigned to the view, downstream coupling occurs. Thus, only the clusters within a view can be enumerated out in the model.
Our current strategy uses guide-side enumeration for the views and model-side enumeration for the clusters. View assignment probabilities for each feature are learned individually. Unfortunately, the model fails to recognize different views. We would love to get this working in Pyro, but our current assessment is that inference doesn’t work with the tools immediately at hand. It seems like there’s also been some past work in this area!
If there is something we’re missing, we would love feedback!
Here is a reproducible example for continuous data:
import matplotlib.pyplot as plt
import numpy as np
from numpy.random import default_rng
from tqdm import tqdm
import torch
import pyro
import pyro.distributions as dist
from torch.distributions import constraints
from pyro.infer import SVI, TraceEnum_ELBO
from torch.nn.functional import pad
from pyro.optim import Adam
assert pyro.__version__.startswith('1.8.4')
def gen_views(cluster_scheme, n_rows=10):
"""cluster scheme takes the form of: [(n_clusters, n_columns), (n_clusters, n_colums)]"""
views = []
for n_clusters, n_cols in cluster_scheme:
views.append(gen_view(n_rows, n_clusters, n_cols))
return np.concatenate(views, axis=1)
def gen_view(n_rows, n_clusters=2, n_cols=1, spacing=3, scale=.3):
rng = default_rng()
cluster_lengths = [n_rows // n_clusters for i in range(n_clusters)]
cluster_lengths[0] += n_rows % n_clusters
clusters = []
for i, cluster_len in enumerate(cluster_lengths):
clusters.append(rng.normal(loc=spacing*i, scale=scale, size=(n_cols, cluster_len)))
return np.concatenate(clusters, axis=1).T
def stickbreak(v, dim=-1):
assert dim in [-1, 0, 1], "only dims -1, 0, 1 supported"
cumprod_one_minus_v = torch.cumprod(1 - v, dim)
if dim == 0:
v_one = pad(v, (0,0,0,1), value=1)
one_c = pad(cumprod_one_minus_v, (0, 0, 1, 0), value=1)
else:
v_one = pad(v, (0, 1), value=1)
one_c = pad(cumprod_one_minus_v, (1, 0), value=1)
return v_one * one_c
def gen_uniform_stick_weights(n_cats):
weights = torch.zeros(n_cats - 1)
weights[0] = 1/n_cats
for i in range(1, n_cats - 1):
stick_remainder = 1 - weights[i - 1]
weights[i] = (1/(1+n_cats-i)) / stick_remainder
return weights
def gen_uniform_stick_each_view(n_cats, n_views):
return torch.vstack([gen_uniform_stick_weights(n_cats) for view in range(n_views)]).T
# --- generate sample data
data = torch.from_numpy(gen_views([(3, 2), (2, 2)], n_rows = 100))
data = ((data - data.mean(axis=0)) / data.std(axis=0)).type(torch.float32)
n_obs, n_features = data.shape
plt.matshow(data)
plt.show()
# ---
# --- specify truncation level
n_views = 3
n_cats = 10
# ---
def model(data):
alpha_v = pyro.sample('alpha_v', dist.Gamma(1, 1))
with pyro.plate('views - 1', n_views - 1):
view_priors = stickbreak(pyro.sample('v_v', dist.Beta(1, alpha_v)))
with pyro.plate('views', n_views):
alpha_c = pyro.sample('alpha_c', dist.Gamma(1, 1))
with pyro.plate('cats - 1', n_cats - 1):
cluster_priors = stickbreak(pyro.sample('v_c', dist.Beta(1, alpha_c)), dim=0)
with pyro.plate("features", n_features):
with pyro.plate(f'cats', n_cats):
mu = pyro.sample('mu', dist.Normal(0, 1))
sigma_sq = pyro.sample('sigma_sq', dist.InverseGamma(1, 1))
with pyro.plate('features_', n_features):
feature_views = pyro.sample(f"feature_views", dist.Categorical(view_priors))
unique_views = torch.unique(feature_views)
for view_i in pyro.plate('unique_views', len(unique_views)):
unique_view = unique_views[view_i] # get active view
active_features = torch.nonzero(unique_view == feature_views).flatten() # get active feature indices for that view
with pyro.plate(f"N_{unique_view}", n_obs):
cats = pyro.sample(f'cat_samps_{unique_view}', dist.Categorical(cluster_priors.T[unique_view]), infer={"enumerate":"parallel"})
for feature_i in pyro.plate(f'active_features_{unique_view}', len(active_features)):
active_feature = active_features[feature_i]
curr_x = pyro.sample(f"X_{active_feature}", dist.Normal(mu[cats, active_feature, unique_view], sigma_sq[cats, active_feature, unique_view]), obs=data[:, active_feature])
init_alpha_c = dist.Gamma(torch.ones([n_views]), torch.ones([n_views])).sample()
init_mu = dist.Uniform(-2.*torch.ones([n_cats, n_features, n_views]), 2.*torch.ones([n_cats, n_features, n_views])).sample()
init_sigma_sq = dist.InverseGamma(torch.ones([n_cats, n_features, n_views]), torch.ones([n_cats, n_features, n_views])).sample()
init_v_v = gen_uniform_stick_weights(n_views)
init_v_c = gen_uniform_stick_each_view(n_cats, n_views)
def guide(data):
loc_alpha_v = pyro.param('loc_alpha_v', dist.Gamma(1, 1), constraint=constraints.positive)
alpha_v = pyro.sample('alpha_v', dist.Delta(loc_alpha_v))
with pyro.plate('views - 1', n_views - 1):
loc_view_priors = pyro.param('loc_v_v', init_v_v, constraint=constraints.unit_interval)
view_priors = stickbreak(pyro.sample('v_v', dist.Delta(loc_view_priors)))
with pyro.plate('views', n_views):
loc_alpha_c = pyro.param('loc_alpha_c', init_alpha_c, constraint=constraints.positive)
alpha_c = pyro.sample('alpha_c', dist.Delta(loc_alpha_c))
with pyro.plate('cats - 1', n_cats - 1):
loc_cluster_priors = pyro.param('loc_v_c', init_v_c, constraint=constraints.unit_interval)
cluster_priors = stickbreak(pyro.sample('v_c', dist.Delta(loc_cluster_priors)), dim=0)
with pyro.plate("features", n_features):
with pyro.plate(f'cats', n_cats):
loc_mu = pyro.param('loc_mu', init_mu)
pyro.sample(f'mu', dist.Delta(loc_mu))
loc_sigma_sq = pyro.param('loc_sigma_sq', init_sigma_sq, constraint=constraints.positive)
pyro.sample(f'sigma_sq', dist.Delta(loc_sigma_sq))
with pyro.plate("features_", n_features):
view_assign = pyro.param('view_probs', torch.ones(n_features, n_views) / n_views, constraint=constraints.unit_interval)
feature_views = pyro.sample(f"feature_views", dist.Categorical(view_assign), infer={"enumerate":"sequential"})
elbo = TraceEnum_ELBO(max_plate_nesting=3)
svi = SVI(model, guide, Adam({'lr': 1e-2}), loss=elbo)
losses = []
for step in tqdm(range(1000)):
loss = svi.step(data)
losses.append(loss)
plt.plot(losses)
plt.show()