Guide-side enumeration strategy fails for DPMM extension (CrossCat)

Related: this post.

CrossCat generalizes the Dirichlet Process Mixture Model (DPMM) by fitting a separate DPMM to each of several disjoint subsets of the features (views). The assignment of features to views is itself governed by a Dirichlet Process. Because the clustering within a view depends on which features are assigned to that view, the two levels are coupled downstream, so only the cluster assignments within a view can be enumerated out in the model.
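
In generative terms, the structure is roughly the following (a sketch in our own notation, not runnable code):

# rough generative sketch of CrossCat (notation ours)
# v[d]     ~ DP(alpha_v)        # view assignment of feature d
# z[n, w]  ~ DP(alpha_c[w])     # cluster assignment of row n within view w
# x[n, d]  ~ Normal(mu[z[n, v[d]], d], sigma[z[n, v[d]], d])
#
# The coupling: which rows co-cluster in view w depends on which features d
# have v[d] == w, and those view assignments are themselves random.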

Our current strategy uses guide-side enumeration for the view assignments and model-side enumeration for the cluster assignments, with the view-assignment probabilities learned individually for each feature. Unfortunately, the model fails to recognize the different views. We would love to get this working in Pyro, but our current assessment is that inference doesn't work with the tools immediately at hand. It seems like there has also been some past work in this area!
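
To make the terminology concrete, here is the distinction sketched with hypothetical names (z, weights, z_probs, and K are placeholders, not from the model below). Model-side enumeration marks a discrete sample site in the model so that TraceEnum_ELBO sums it out exactly; guide-side enumeration marks the site in the guide, where its probabilities are learnable parameters:

# model-side: the discrete site lives in the model and is summed out exactly
z = pyro.sample("z", dist.Categorical(weights), infer={"enumerate": "parallel"})

# guide-side: the discrete site lives in the guide with learnable probabilities
z_probs = pyro.param("z_probs", torch.ones(K) / K, constraint=constraints.simplex)
z = pyro.sample("z", dist.Categorical(z_probs), infer={"enumerate": "sequential"})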

If there is something we’re missing, we would love feedback!

Here is a reproducible example for continuous data:

import matplotlib.pyplot as plt
import numpy as np
from numpy.random import default_rng
from tqdm import tqdm

import torch
import pyro
import pyro.distributions as dist
from torch.distributions import constraints
from pyro.infer import SVI, TraceEnum_ELBO
from torch.nn.functional import pad
from pyro.optim import Adam

assert pyro.__version__.startswith('1.8.4')

def gen_views(cluster_scheme, n_rows=10):
    """cluster scheme takes the form of: [(n_clusters, n_columns), (n_clusters, n_colums)]"""
    views = []
    for n_clusters, n_cols in cluster_scheme:
        views.append(gen_view(n_rows, n_clusters, n_cols)) 

    return np.concatenate(views, axis=1)

def gen_view(n_rows, n_clusters=2, n_cols=1, spacing=3, scale=.3):
    rng = default_rng()

    # split rows evenly across clusters; the first cluster absorbs the remainder
    cluster_lengths = [n_rows // n_clusters for i in range(n_clusters)]
    cluster_lengths[0] += n_rows % n_clusters

    clusters = []

    # each cluster is a Gaussian blob, with means spaced `spacing` apart
    for i, cluster_len in enumerate(cluster_lengths):
        clusters.append(rng.normal(loc=spacing*i, scale=scale, size=(n_cols, cluster_len)))

    return np.concatenate(clusters, axis=1).T

def stickbreak(v, dim=-1):
    """Convert n - 1 stick-breaking fractions v into n mixture weights along `dim`."""
    assert dim in [-1, 0, 1], "only dims -1, 0, 1 supported"

    cumprod_one_minus_v = torch.cumprod(1 - v, dim)
    if dim == 0:
        v_one = pad(v, (0, 0, 0, 1), value=1)
        one_c = pad(cumprod_one_minus_v, (0, 0, 1, 0), value=1)
    else:
        v_one = pad(v, (0, 1), value=1)
        one_c = pad(cumprod_one_minus_v, (1, 0), value=1)
    return v_one * one_c
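
# quick sanity check (illustrative): stick-breaking weights always sum to 1
assert torch.isclose(stickbreak(torch.rand(9)).sum(), torch.tensor(1.))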

def gen_uniform_stick_weights(n_cats):
    """Stick-breaking fractions that break into uniform weights 1/n_cats.

    Equivalently, weights[i] == 1/(n_cats - i).
    """
    weights = torch.zeros(n_cats - 1)
    weights[0] = 1/n_cats

    for i in range(1, n_cats - 1):
        stick_remainder = 1 - weights[i - 1]
        weights[i] = (1/(1+n_cats-i)) / stick_remainder

    return weights

def gen_uniform_stick_each_view(n_cats, n_views):
    # one column of uniform stick weights per view; shape (n_cats - 1, n_views)
    return torch.vstack([gen_uniform_stick_weights(n_cats) for _ in range(n_views)]).T
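
# sanity check (illustrative): uniform stick weights break into equal probabilities,
# e.g. stickbreak(gen_uniform_stick_weights(4)) is approximately [0.25, 0.25, 0.25, 0.25]
assert torch.allclose(stickbreak(gen_uniform_stick_weights(4)), torch.full((4,), 0.25))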

# --- generate sample data
data = torch.from_numpy(gen_views([(3, 2), (2, 2)], n_rows=100))  # two true views: 3 clusters over 2 features, then 2 clusters over 2 features
data = ((data - data.mean(axis=0)) / data.std(axis=0)).type(torch.float32)
n_obs, n_features = data.shape
plt.matshow(data)
plt.show()
# ---

# --- specify truncation levels for the stick-breaking approximations
n_views = 3
n_cats = 10
# ---

def model(data):
    alpha_v = pyro.sample('alpha_v', dist.Gamma(1, 1))

    with pyro.plate('views - 1', n_views - 1):
        view_priors = stickbreak(pyro.sample('v_v', dist.Beta(1, alpha_v)))

    with pyro.plate('views', n_views):
        alpha_c = pyro.sample('alpha_c', dist.Gamma(1, 1))        
        
        with pyro.plate('cats - 1', n_cats - 1):
            cluster_priors = stickbreak(pyro.sample('v_c', dist.Beta(1, alpha_c)), dim=0)

        with pyro.plate("features", n_features):
            with pyro.plate(f'cats', n_cats):
                mu = pyro.sample('mu', dist.Normal(0, 1))
                sigma_sq = pyro.sample('sigma_sq', dist.InverseGamma(1, 1))

    with pyro.plate('features_', n_features):
        # each feature picks a view; this site is enumerated in the guide
        feature_views = pyro.sample("feature_views", dist.Categorical(view_priors))

    unique_views = torch.unique(feature_views)

    for view_i in pyro.plate('unique_views', len(unique_views)):
        unique_view = unique_views[view_i]  # the active view
        active_features = torch.nonzero(unique_view == feature_views).flatten()  # feature indices assigned to this view

        with pyro.plate(f"N_{unique_view}", n_obs):
            # cluster assignments are enumerated out in the model
            cats = pyro.sample(f'cat_samps_{unique_view}', dist.Categorical(cluster_priors.T[unique_view]), infer={"enumerate": "parallel"})

            for feature_i in pyro.plate(f'active_features_{unique_view}', len(active_features)):
                active_feature = active_features[feature_i]
                # Normal expects a standard deviation, so take the square root of the sampled variance
                pyro.sample(f"X_{active_feature}", dist.Normal(mu[cats, active_feature, unique_view], sigma_sq[cats, active_feature, unique_view].sqrt()), obs=data[:, active_feature])

init_alpha_c = dist.Gamma(torch.ones([n_views]), torch.ones([n_views])).sample()
init_mu = dist.Uniform(-2.*torch.ones([n_cats, n_features, n_views]), 2.*torch.ones([n_cats, n_features, n_views])).sample()
init_sigma_sq = dist.InverseGamma(torch.ones([n_cats, n_features, n_views]), torch.ones([n_cats, n_features, n_views])).sample()
init_v_v = gen_uniform_stick_weights(n_views)
init_v_c = gen_uniform_stick_each_view(n_cats, n_views)

def guide(data):
    # pyro.param accepts a callable initializer; a Pyro distribution is callable, so this samples an initial value
    loc_alpha_v = pyro.param('loc_alpha_v', dist.Gamma(1, 1), constraint=constraints.positive)
    alpha_v = pyro.sample('alpha_v', dist.Delta(loc_alpha_v))

    with pyro.plate('views - 1', n_views - 1):
        loc_view_priors = pyro.param('loc_v_v', init_v_v, constraint=constraints.unit_interval)
        view_priors = stickbreak(pyro.sample('v_v', dist.Delta(loc_view_priors)))

    with pyro.plate('views', n_views):
        loc_alpha_c = pyro.param('loc_alpha_c', init_alpha_c, constraint=constraints.positive)
        alpha_c = pyro.sample('alpha_c', dist.Delta(loc_alpha_c))

        with pyro.plate('cats - 1', n_cats - 1):
            loc_cluster_priors = pyro.param('loc_v_c', init_v_c, constraint=constraints.unit_interval)
            cluster_priors = stickbreak(pyro.sample('v_c', dist.Delta(loc_cluster_priors)), dim=0)

        with pyro.plate("features", n_features):
            with pyro.plate('cats', n_cats):
                loc_mu = pyro.param('loc_mu', init_mu)
                pyro.sample('mu', dist.Delta(loc_mu))

                loc_sigma_sq = pyro.param('loc_sigma_sq', init_sigma_sq, constraint=constraints.positive)
                pyro.sample('sigma_sq', dist.Delta(loc_sigma_sq))

    with pyro.plate("features_", n_features):
        view_assign = pyro.param('view_probs', torch.ones(n_features, n_views) / n_views, constraint=constraints.unit_interval)
        feature_views = pyro.sample(f"feature_views", dist.Categorical(view_assign), infer={"enumerate":"sequential"})

elbo = TraceEnum_ELBO(max_plate_nesting=3)
svi = SVI(model, guide, Adam({'lr': 1e-2}), loss=elbo)

losses = []
for step in tqdm(range(1000)):
    loss = svi.step(data)
    losses.append(loss)

plt.plot(losses)
plt.show()
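
To see where inference ends up, you can inspect the learned view-assignment probabilities after training (view_probs is the param name registered in the guide above):

print(pyro.param('view_probs').detach())

If inference were separating features into views, each feature's row would concentrate on a single view; as noted above, we do not see this.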

It’s interesting to note that if you fix the view assignments, the model fits the clusters within each individual view properly.
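
For example, one way to fix them (a sketch; the assignment tensor here matches the generating scheme above, with the first two features in one view and the last two in another) is to condition the model's feature_views site and hide that site from the guide:

import pyro.poutine as poutine

fixed_views = torch.tensor([0, 0, 1, 1])  # ground-truth grouping from gen_views above
fixed_model = poutine.condition(model, data={"feature_views": fixed_views})
fixed_guide = poutine.block(guide, hide=["feature_views"])
svi_fixed = SVI(fixed_model, fixed_guide, Adam({'lr': 1e-2}), loss=TraceEnum_ELBO(max_plate_nesting=3))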

cc @fritzo