Hi,
I am attempting to tweak a Gaussian Mixture Model implementation so that it performs a within-cluster regression on a target variable, in addition to the standard clustering. The motivation is to do simple regression while jointly learning structure in the covariates.
I’m struggling to get it implemented correctly in higher dimensions. Would appreciate any help!
Background
(I am using Bishop’s PRML notation).
In the standard GMM formulation, we can write the probability of observing a data point as follows:
p(x) = \sum_{z} p(z)p(x|z)
where z is a latent variable, more precisely a binary one-hot (1-of-K) random vector that controls the assignment of a point to a cluster.
We can re-write this as:
p(x) = \sum_{k=1}^{K} \pi_k \, p(x \mid z_k = 1), where k is the cluster index and \pi_k = p(z_k = 1). Here K is the number of clusters (fixed).
Each cluster has its own \mu_k and \Sigma_k parameters, which fully parameterize p(x \mid z_k = 1).
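Written out with the Gaussian components, this is the usual mixture density:
p(x) = \sum_{k=1}^{K} \pi_k \, \mathcal{N}(x \mid \mu_k, \Sigma_k)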
In my setup, there is another random variable Y that we observe alongside X. We can write the joint distribution of (X, Y) as:
p(x, y) = \sum_{z} p(z)p(x, y|z) = \sum_{z} p(z)p(x|z)p(y|x, z)
Assume a simple Gaussian form for p(y|x, z), with the mean parameterized by a simple linear model, \mathbf{x}^T \theta_k (ignoring the intercept). Each component k has its own coefficient vector \theta_k.
Note that each Gaussian component now has three sets of parameters: \mu_k, \Sigma_k, and \theta_k.
Stated plainly in a generative sense, given some value of z, we can generate x. Then given that value of z and x, we can use the relevant \theta_k vector to generate y.
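Putting the pieces together (writing \sigma^2 for the fixed noise variance of the regression), the density I am trying to fit is:
p(x, y) = \sum_{k=1}^{K} \pi_k \, \mathcal{N}(x \mid \mu_k, \Sigma_k) \, \mathcal{N}(y \mid \mathbf{x}^T \theta_k, \sigma^2)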
NumPyro Implementation
I am struggling to write the p(y|x, z) part of the model correctly. I am extending the numpyro.plate() construct as in the example GMM implementation, but I have not fully understood how the indices are affected by marginalization, so I cannot reliably implement the simple linear function \mathbf{x}^T \theta_k. I have run the GMM piece on its own, and it works just fine. A (mostly) self-reproducing example is below.
import os
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpyro
import numpyro.distributions as dist
from numpyro import handlers
from numpyro.contrib.funsor import config_enumerate, infer_discrete
from numpyro.infer import MCMC, NUTS, init_to_median, init_to_uniform, init_to_value
# Define some test data to study
# Four multivariate Gaussians, each with its own regression coefficients
# Constants
n_samples = 1000
K = 4
seeds = jax.random.split(jax.random.PRNGKey(42), n_samples)  # one key per observation
# Means of each of the four components
mu_comps = jnp.array([[-1, -1], [+1, -1], [-1, +1], [+1, +1]])
# Per-component regression coefficients
coeff = jnp.array([[1, 4], [3, 2], [4, 1], [-3, -2]])
var = 0.1
# Create arrays to hold the simulated samples
X_clust = jnp.zeros((n_samples, 2))
y_train = jnp.zeros((n_samples,))
# Probability forward model
for i in range(n_samples):
    # Draw samples, using a fresh subkey for each sampling statement
    key_pi, key_z, key_x = jax.random.split(seeds[i], 3)
    pi = dist.Dirichlet(jnp.ones(K) / float(K)).sample(key_pi)
    z = dist.Categorical(pi).sample(key_z)
    x = dist.MultivariateNormal(mu_comps[z, :], var * jnp.eye(2)).sample(key_x)
    y = jnp.multiply(coeff[z, :], x).sum()  # + bias[z]
    # Assign samples to arrays
    X_clust = X_clust.at[i, :].set(x)
    y_train = y_train.at[i].set(y)
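As a quick sanity check on the simulated data (just a sketch; I colour by y_train since the assignments z are not kept):
plt.scatter(X_clust[:, 0], X_clust[:, 1], c=y_train, s=5)  # clusters sit near the four component means
plt.xlabel("x1")
plt.ylabel("x2")
plt.show()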
@config_enumerate
def model(K, X_clust=None, X_reg=None, y=None):
"""
Credit to the formulation in for base GMM in high dimensions: https://forum.pyro.ai/t/mixture-model-with-discrete-data-in-numpyro/5983
Assume for simplicity that X_clust = X_reg.
"""
# Constants
N = X_clust.shape[0] # Number of datapoints
D = X_clust.shape[1] # Number of input dimensions for clustering
R = X_reg.shape[1] # Number of input dimensions for regression
# Assignment probability vector for the latent variable
cluster_probs = numpyro.sample(
"cluster_probs", dist.Dirichlet(jnp.ones(K) / float(K))
)
# Parameters associated with each Gaussian component (k \in K)
# Covariance parameter is excluded for simplicity
with numpyro.plate("mixture_components", K, dim=None):
locs = numpyro.sample(
"locs", dist.MultivariateNormal(jnp.zeros(D), 10 * jnp.eye(D))
)
thetas = numpyro.sample(
"thetas", dist.MultivariateNormal(jnp.zeros(R), 10 * jnp.eye(R))
)
intercepts = numpyro.sample("intercepts", dist.Normal(0, 10))
# Generative process for each observation through a latent variable
# The use of config_enumerate marginalizes out the discrete latent variable
with numpyro.plate("data", N, dim=None) as ind:
# z ~ P(z)
assignment = numpyro.sample("assignment", dist.Categorical(cluster_probs))
# X|Z ~ P(x|z)
x_obs = numpyro.sample(
"x_obs",
dist.MultivariateNormal(
locs[assignment, :],
covariance_matrix=0.1 * jnp.eye(2),
),
obs=X_clust,
)
# Observation likelihood: P(y|x, z)
# I am not sure the this is correct.
# The goal is to multiply each data point by the assignment theta, and sum over all values of assignment
X_batch = X_reg[ind]
y_pred = jnp.multiply(X_batch, thetas[assignment, :]).sum(axis=1)
numpyro.sample(
"y_obs",
dist.Normal(
loc=y_pred + intercepts[assignment],
scale=0.5, # Fixed for ease of implmentation
),
obs=y,
)
# Inference
kernel = NUTS(model, max_tree_depth=10, init_strategy=init_to_median())
mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000)
mcmc.run(
    jax.random.PRNGKey(42),
    K=4,
    X_clust=X_clust,
    X_reg=X_clust,
    y=y_train,
)
mcmc.print_summary()
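For reference, here is a sketch of what I think the P(y|x, z) block might need to look like, assuming config_enumerate adds the enumeration dimension for assignment to the left of the plate dimensions (so assignment has shape (K, 1) rather than (N,) during enumeration). I am not at all sure this is right, which is really the crux of my question:
# Inside the "data" plate, replacing the y_pred / y_obs lines above.
# thetas has shape (K, R); under enumeration assignment has shape (K, 1),
# so thetas[assignment] has shape (K, 1, R) and broadcasts against
# X_batch of shape (N, R).
theta_sel = thetas[assignment]               # (..., R), with or without enumeration
y_pred = (X_batch * theta_sel).sum(axis=-1)  # dot product over the feature axis, not axis=1
numpyro.sample(
    "y_obs",
    dist.Normal(loc=y_pred + intercepts[assignment], scale=0.5),
    obs=y,
)
Is summing over the last axis the right way to keep the enumeration dimension intact, or is there a cleaner idiom (e.g. Vindex) for this per-component dot product?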
Please let me know if anything is unclear.