Extending GMMs for regressions-within-clusters

Hi,

I am attempting to extend a Gaussian Mixture Model implementation so that, in addition to the standard clustering scheme, it performs a within-cluster regression on some target variable. The motivation is to do simple regression while jointly learning structure in the covariates.

I’m struggling to get it implemented correctly in higher dimensions. Would appreciate any help!

Background

(I am using Bishop’s PRML notation).

In the standard GMM formulation, we can write the density of a data point as follows:

p(x) = \sum_{z} p(z)p(x|z)

where z is a latent variable, more precisely a binary random vector with a one-hot (1-of-K) encoding, that controls the assignment of a point to a cluster.

We can re-write this as:

p(x) = \sum_{k=1}^{K} \pi_k p(x | z_k = 1), where k is the cluster index, K is the (fixed) number of clusters, and \pi_k = p(z_k = 1) is the mixing weight, i.e. the probability that a point belongs to cluster k.

Each cluster k has its own parameters \mu_{k} and \Sigma_{k}, which fully parameterize p(x | z_k = 1) = \mathcal{N}(x | \mu_{k}, \Sigma_{k}).
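A minimal sketch of evaluating this mixture density, assuming JAX's `jax.scipy` helpers (the function name `gmm_log_density` is hypothetical, not part of NumPyro). Working in log space and combining components with log-sum-exp avoids underflow when individual component densities are tiny:

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax.scipy.stats import multivariate_normal


def gmm_log_density(x, weights, means, covs):
    """log p(x) = log sum_k pi_k N(x | mu_k, Sigma_k), evaluated for each row of x.

    x: (N, D) points, weights: (K,) mixing weights pi_k,
    means: (K, D), covs: (K, D, D). Returns shape (N,).
    """
    K = means.shape[0]
    # log N(x_n | mu_k, Sigma_k) for every component and point, shape (K, N)
    log_comp = jnp.stack(
        [multivariate_normal.logpdf(x, means[k], covs[k]) for k in range(K)]
    )
    # Add log pi_k, then log-sum-exp over the component axis
    return logsumexp(jnp.log(weights)[:, None] + log_comp, axis=0)
```

With identical components the result reduces to a single Gaussian log-density, which makes a quick sanity check.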

In my setup there is another random variable Y, which we observe together with X. We can write the joint distribution of (X, Y) as:

p(x, y) = \sum_{z} p(z)p(x, y|z) = \sum_{z} p(z)p(x|z)p(y|x, z)

Assume a simple Gaussian form for p(y|x, z), with the mean function given by a linear model, \mathbf{x}^T \theta_{k} (ignoring the intercept), where \theta_{k} is a coefficient vector belonging to component k.
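Written out explicitly, and assuming a noise variance \sigma^2 shared by all components (a simplification; the NumPyro model fixes the noise scale), the per-point density, the log-likelihood obtained by marginalizing z, and the resulting responsibilities would be as follows. Note that the responsibilities depend on y as well as x, so the regression fit also informs the clustering:

```latex
p(x_n, y_n) = \sum_{k=1}^{K} \pi_k \, \mathcal{N}(x_n \mid \mu_k, \Sigma_k) \, \mathcal{N}(y_n \mid x_n^\top \theta_k, \sigma^2)

\log p(\mathbf{X}, \mathbf{y}) = \sum_{n=1}^{N} \log \sum_{k=1}^{K} \pi_k \, \mathcal{N}(x_n \mid \mu_k, \Sigma_k) \, \mathcal{N}(y_n \mid x_n^\top \theta_k, \sigma^2)

\gamma(z_{nk}) = \frac{\pi_k \, \mathcal{N}(x_n \mid \mu_k, \Sigma_k) \, \mathcal{N}(y_n \mid x_n^\top \theta_k, \sigma^2)}{\sum_{j=1}^{K} \pi_j \, \mathcal{N}(x_n \mid \mu_j, \Sigma_j) \, \mathcal{N}(y_n \mid x_n^\top \theta_j, \sigma^2)}
```

Adding a per-component intercept b_k simply replaces x_n^\top \theta_k with x_n^\top \theta_k + b_k.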

Note that each Gaussian component now has three parameters: \mu_{k}, \Sigma_{k}, \theta_{k}.

In generative terms: given a value of z, we generate x. Then, given z and x, we use the corresponding \theta_k vector to generate y.
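A sketch of this generative story as ancestral sampling in JAX, assuming a covariance shared by all components and a fixed noise scale \sigma (both simplifications; `sample_regression_mixture` is a hypothetical helper). Each of the three draws gets its own PRNG key so that z, x, and y use independent randomness:

```python
import jax
import jax.numpy as jnp


def sample_regression_mixture(key, n, weights, means, cov, thetas, noise_scale):
    """z ~ Cat(pi), x | z ~ N(mu_z, Sigma), y | x, z ~ N(x^T theta_z, sigma^2)."""
    key_z, key_x, key_y = jax.random.split(key, 3)
    # z_n ~ Categorical(pi): one component index per data point, shape (n,)
    z = jax.random.categorical(key_z, jnp.log(weights), shape=(n,))
    # x_n | z_n ~ N(mu_{z_n}, Sigma) with a shared covariance, shape (n, D)
    x = jax.random.multivariate_normal(key_x, means[z], cov)
    # y_n | x_n, z_n ~ N(x_n^T theta_{z_n}, sigma^2), shape (n,)
    y_mean = jnp.sum(x * thetas[z], axis=-1)
    y = y_mean + noise_scale * jax.random.normal(key_y, (n,))
    return z, x, y
```

Drawing everything in one vectorized call avoids the Python loop and the repeated `.at[...].set(...)` updates.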

NumPyro Implementation

I am struggling to write the p(y|x, z) part of the model correctly. I am extending the numpyro.plate() construct as in the example GMM implementation, but I don't fully understand how marginalization affects the array indices, so I have not been able to implement the simple function (\mathbf{x}^T \theta_{k}) reliably. I have run the GMM piece on its own, and it works just fine. A (mostly) self-contained example is below.

import os
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpyro
import numpyro.distributions as dist
from numpyro import handlers
from numpyro.contrib.funsor import config_enumerate, infer_discrete
from numpyro.infer import MCMC, NUTS, init_to_median, init_to_uniform, init_to_value

# Define some test data to study
# Four two-dimensional Gaussians (K = 4)

# Constants

n_samples = 1000
K = 4
seeds = jax.random.split(jax.random.PRNGKey(42), 100)

# Means of each of the four components

mu_comps = jnp.array([[-1, -1], [+1, -1], [-1, +1], [+1, +1]])
# Regression coefficients theta_k for each component
coeff = jnp.array([[1, 4], [3, 2], [4, 1], [-3, -2]])
var = 0.1

# Arrays to hold the synthetic data

X_clust = jnp.zeros((n_samples, 2))
y_train = jnp.zeros((n_samples,))
# Probability forward model

for i in range(0, n_samples):

    # Draw samples

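    # Use an independent PRNG key for each draw (reusing one key correlates them)
    key_pi, key_z, key_x = jax.random.split(seeds[i], 3)
    pi = dist.Dirichlet(jnp.ones(K) / float(K)).sample(key_pi)
    z = dist.Categorical(pi).sample(key_z)
    x = dist.MultivariateNormal(mu_comps[z, :], var * jnp.eye(2)).sample(key_x)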
    y = jnp.multiply(coeff[z, :], x).sum()  # + bias[z]

    # Assign samples to arrays

    X_clust = X_clust.at[i, :].set(x)
    y_train = y_train.at[i].set(y)

@config_enumerate
def model(K, X_clust=None, X_reg=None, y=None):
    """
    Credit for the base high-dimensional GMM formulation: https://forum.pyro.ai/t/mixture-model-with-discrete-data-in-numpyro/5983
    Assume for simplicity that X_clust = X_reg.
    """

    # Constants

    N = X_clust.shape[0]  # Number of datapoints
    D = X_clust.shape[1]  # Number of input dimensions for clustering
    R = X_reg.shape[1]  # Number of input dimensions for regression

    # Assignment probability vector for the latent variable

    cluster_probs = numpyro.sample(
        "cluster_probs", dist.Dirichlet(jnp.ones(K) / float(K))
    )

    # Parameters associated with each Gaussian component (k \in K)
    # Covariance parameter is excluded for simplicity

    with numpyro.plate("mixture_components", K, dim=None):

        locs = numpyro.sample(
            "locs", dist.MultivariateNormal(jnp.zeros(D), 10 * jnp.eye(D))
        )
        thetas = numpyro.sample(
            "thetas", dist.MultivariateNormal(jnp.zeros(R), 10 * jnp.eye(R))
        )
        intercepts = numpyro.sample("intercepts", dist.Normal(0, 10))

    # Generative process for each observation through a latent variable
    # The use of config_enumerate marginalizes out the discrete latent variable

    with numpyro.plate("data", N, dim=None) as ind:

        # z ~ P(z)

        assignment = numpyro.sample("assignment", dist.Categorical(cluster_probs))

        # X|Z ~ P(x|z)

        x_obs = numpyro.sample(
            "x_obs",
            dist.MultivariateNormal(
                locs[assignment, :],
                covariance_matrix=0.1 * jnp.eye(D),
            ),
            obs=X_clust,
        )

        # Observation likelihood: P(y|x, z)
        # I am not sure this is correct.
        # The goal is to multiply each data point by the assignment theta, and sum over all values of assignment

        X_batch = X_reg[ind]
        y_pred = jnp.multiply(X_batch, thetas[assignment, :]).sum(axis=1)

        numpyro.sample(
            "y_obs",
            dist.Normal(
                loc=y_pred + intercepts[assignment],
                scale=0.5,  # Fixed for ease of implementation
            ),
            obs=y,
        )


# Inference

kernel = NUTS(model, max_tree_depth=10, init_strategy=init_to_median())
mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000)
mcmc.run(
    jax.random.PRNGKey(42),
    K=4,
    X_clust=X_clust,
    X_reg=X_clust,
    y=y_train,
)
mcmc.print_summary()

Please let me know if anything is unclear.

What you are doing is very similar to what I’m trying to do, except that I have a discrete mixture model and a logistic regression response rather than your GMM and linear regression. Coming from a frequentist statistics background, the way I’m thinking about these models is that they are essentially multilevel models where you don’t know the levels in advance.

Looking at your implementation, it looks like you want separate intercepts and gradients for each cluster found by the mixture model. Do I have that correct? The response part of my model has global gradients but an intercept per cluster, and it should be easy to generalise to a gradient per cluster as well.

I posted my model code in this thread, which you may find helpful: Model with a joint posterior distribution
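To make the difference between the two response parts concrete, here is a hypothetical sketch (function names are illustrative, not taken from either codebase) of the per-component linear predictor in both variants. Each returns a (K, N) table with one row per cluster; for a logistic response, that table would be used as logits (e.g. `dist.Bernoulli(logits=eta)` in NumPyro) rather than as a Gaussian mean:

```python
import jax.numpy as jnp


def eta_per_cluster(X, thetas, intercepts):
    # Separate slopes and intercepts per cluster: eta[k, n] = x_n^T theta_k + b_k
    # X: (N, R), thetas: (K, R), intercepts: (K,) -> (K, N)
    return thetas @ X.T + intercepts[:, None]


def eta_global_slopes(X, beta, intercepts):
    # Shared slopes, per-cluster intercept: eta[k, n] = x_n^T beta + b_k
    # X: (N, R), beta: (R,), intercepts: (K,) -> (K, N)
    return (X @ beta)[None, :] + intercepts[:, None]
```

The global-slope variant has R + K regression parameters instead of K * (R + 1), which is the usual multilevel trade-off between pooling and flexibility.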


Hi @jim,

Thanks x1000 for your thoughts! Your previous issues have already been very helpful.

I agree, our formulations are very much the same. Your description of how our within-cluster regression parameters differ is also correct. I also agree with your intuition that these are similar to multilevel/hierarchical models, with unknown index assignments.

I just modified my “regression chunk” inside the “data” plate to follow the axis of summation used in your code:


y_pred = jnp.sum(thetas[assignment, :] * X_reg, axis=-1) + intercepts[assignment]

numpyro.sample(
    "y_obs",
    dist.Normal(
        loc=y_pred,
        scale=0.1,  # Fixed for ease of implementation
    ),
    obs=y,
)

And voila, the sampler is able to infer the correct means and regression coefficients for the synthetic data! I’m still going to scratch my head to understand why axis=-1 is the correct axis. The way the plate construct adds dimensions is quite confusing.
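A sketch of why axis=-1 is the right choice, assuming the model's `max_plate_nesting` is 1 so that NumPyro places the enumeration dimension at -2, to the left of the "data" plate at dim -1. Under enumeration, `assignment` holds every candidate cluster index with shape (K, 1), rather than one sampled index per point. The arrays below only mimic those shapes:

```python
import jax.numpy as jnp

K, N, R = 4, 1000, 2

thetas = jnp.arange(K * R, dtype=jnp.float32).reshape(K, R)  # (K, R)
X_reg = jnp.ones((N, R))                                       # (N, R)

# Enumerated case: every candidate index k, in a new dim left of the plate
assignment = jnp.arange(K).reshape(K, 1)                       # (K, 1)
coef = thetas[assignment, :]                                   # (K, 1, R)
prod = coef * X_reg                                            # broadcasts to (K, N, R)
good = jnp.sum(prod, axis=-1)  # (K, N): x_n^T theta_k for every k and n
bad = jnp.sum(prod, axis=1)    # (K, R): sums over data points, not features

# Non-enumerated case (one sampled index per point): axis=-1 still works
sampled = jnp.zeros(N, dtype=jnp.int32)                        # (N,)
per_point = jnp.sum(thetas[sampled, :] * X_reg, axis=-1)       # (N,)
```

Because the feature axis R is always last, axis=-1 selects it whether or not enumeration prepends extra batch dimensions; axis=1 only coincides with it when there is no enumeration dimension.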

Side note: you might find this paper by Shahbaba and Neal interesting. They extend this sort of approach by using Bayesian nonparametrics for the clustering piece. For me that’s a natural extension, but I’m not sure whether it also applies to your case. AFAIK, this may not be possible to implement in NumPyro.

https://jmlr.org/papers/v10/shahbaba09a.html


Nice one! Thanks for the reference too :+1:

We actually built a small package that makes it easier to build these kinds of regression mixture models in NumPyro: GitHub - compmem/spamr: Mixture regression models for NumPyro.


@amifalk Thanks for sharing! I missed this in my search. Nice to see you arrive at exactly the same formulation.

Do you have any plans to extend this to the nonparametric case (as in the Shahbaba and Neal paper above)? I noticed you have a few issues here on using DPMMs, so it could be a very cool extension. It’s certainly on my agenda in the near future.

If you check the spamr.py file in that repo, you should see a function that allows you to do exactly this. We use the truncated stick-breaking representation of the DP, and we structured the core code to be modular so you can copy and paste chunks of spamr even if you don’t want to rely completely on our sparse formulation (e.g. sample_ordered_beta will help with identifiability if you’re running inference with MCMC).
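For readers unfamiliar with the truncated stick-breaking construction, here is a minimal sketch in JAX (not spamr's implementation; `stick_breaking_weights` is a hypothetical helper). With T components, T - 1 proportions v_t ~ Beta(1, alpha) are broken off a unit-length stick, and the last component takes whatever remains, so the weights always sum to one:

```python
import jax
import jax.numpy as jnp


def stick_breaking_weights(v):
    """Map T - 1 stick proportions v_t in (0, 1) to T mixture weights.

    pi_1 = v_1, pi_t = v_t * prod_{j<t} (1 - v_j), and the final weight is
    the leftover stick, which is what makes the truncation finite.
    """
    # Stick length remaining before each break: [1, (1-v_1), (1-v_1)(1-v_2), ...]
    remaining = jnp.concatenate([jnp.ones(1), jnp.cumprod(1.0 - v)])
    # Pad with a final 1 so the last component absorbs the leftover stick
    v_padded = jnp.concatenate([v, jnp.ones(1)])
    return v_padded * remaining


# Example: T = 10 components under a truncated DP(alpha) prior
alpha, T = 2.0, 10
v = jax.random.beta(jax.random.PRNGKey(0), 1.0, alpha, shape=(T - 1,))
weights = stick_breaking_weights(v)
```

Inside a NumPyro model the proportions would typically come from a Beta(1, alpha) sample site instead of `jax.random.beta`; since every operation is differentiable in v, the construction is compatible with HMC/NUTS.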

I’ll pin the version number in the dependencies after the latest version of NumPyro releases so it’s less hacky to use (we currently rely on a bug fix sitting on the main branch).


Apologies, I only glanced at your package quickly this morning and assumed it was based on a fixed, finite number of components. Really neat bit of work! I wasn’t aware of the approximate stick-breaking representation of DPs, but it makes a lot of sense, and it’s wonderful that it makes the problem amenable to HMC.

Hi @amifalk,

I spent some time studying your model. While I haven’t entirely understood it, I suspect it takes the following rough form:

F ~ DP(alpha, F_0)
theta_i | F ~ F, for i = 1, …, n (theta = (lam, beta))
X_i | theta_i ~ f(x_i | theta_i), for i = 1, …, n (covariate likelihood)
Y_i | X_i, theta_i ~ g(y_i | X_i, theta_i, noise), for i = 1, …, n (response likelihood)

And you’re using a finite stick-breaking approximation to draw from a DP.

Why do you have only one “obs” statement (line 107 in spamr.py), for the response variable y? I would have expected an additional likelihood term for the covariates X (the “covariate likelihood”). Am I missing something?

That’s mostly right, but we don’t model the covariate likelihood as Shahbaba and Neal do. We only model P(y|X).
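A sketch of the distinction, assuming mixture weights that do not depend on x and a shared noise scale (the helper `mixture_regression_loglik` is hypothetical). Modeling only P(y|X) keeps the Gaussian regression term per component; adding the covariate likelihood multiplies in N(x | \mu_k, \Sigma_k), which lets the covariates themselves inform the cluster assignments:

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax.scipy.stats import multivariate_normal, norm


def mixture_regression_loglik(X, y, weights, thetas, noise_scale, means=None, covs=None):
    """Total log-likelihood of a finite mixture of linear regressions.

    Without means/covs: p(y | x) = sum_k pi_k N(y | x^T theta_k, sigma^2).
    With means/covs:    p(x, y) = sum_k pi_k N(x | mu_k, Sigma_k) N(y | x^T theta_k, sigma^2).
    X: (N, R), y: (N,), weights: (K,), thetas: (K, R).
    """
    # (K, N) table of log pi_k + log N(y_n | x_n^T theta_k, sigma^2)
    y_mean = X @ thetas.T  # (N, K)
    log_terms = jnp.log(weights)[:, None] + norm.logpdf(y, y_mean.T, noise_scale)
    if means is not None:
        # Covariate likelihood: add log N(x_n | mu_k, Sigma_k) for each component
        log_terms = log_terms + jnp.stack(
            [multivariate_normal.logpdf(X, means[k], covs[k]) for k in range(thetas.shape[0])]
        )
    # Marginalize the component per point, then sum over points
    return jnp.sum(logsumexp(log_terms, axis=0))
```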


Thanks! Given the way your code is structured, it doesn’t seem like much overhead to incorporate the covariate likelihood, so I’ll try it out on my end.

@amifalk just discovered this old-ish paper that also accomplishes a version of your work!

Thanks for sharing! I believe this kind of mixture regression has actually existed since the late 1970s. For other similar formulations, see also, for example:
