Extending GMM example to higher dimensional datasets

Hi,

I was trying to expand the GMM tutorial to higher dimensions. I noticed there is a previous issue in the pyro context, but I don’t think the solution offered exists (.independent() method no longer exists).

Here is a minimal reproducible example of my attempt so far:

import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.funsor import config_enumerate, infer_discrete
from numpyro.infer import MCMC, NUTS

# Generate some random data

data = dist.Normal(0, 1).sample(jax.random.PRNGKey(42), (100, 13))
print(data.shape)

# Define NumPyro model for GMMs in high dimensions


@config_enumerate
def model(data):

    # Constants

    N = data.shape[0]  # Number of datapoints
    D = 13  # Number of input dimensions
    K = 3  # Number of mixture components

    # Global parameters
    # TODO: make scale/covariance vary by component

    weights = numpyro.sample("weights", dist.Dirichlet(0.5 * jnp.ones(K)))
    scale = numpyro.sample("scale", dist.LogNormal(0.0, 2.0))

    # Plate construct for mean vector of each gaussian component

    with numpyro.plate("components", K):
        locs = numpyro.sample(
            "locs", dist.Normal(jnp.ones(shape=(D, 1)), 10.0)
        )  # Shape = DxK

    # Plate construct to loop over data

    with numpyro.plate("data", N):
        assignment = numpyro.sample("assignment", dist.Categorical(weights))
        numpyro.sample(
            "obs",
            dist.MultivariateNormal(locs[:, assignment], scale * jnp.eye(D, D)),
            obs=data,
        )


# Inference

kernel = NUTS(model, max_tree_depth=10)
mcmc = MCMC(kernel, num_warmup=10, num_samples=10)
mcmc.run(jax.random.PRNGKey(42), data)
mcmc.print_summary()

The error that I’m getting is:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
File ~/miniconda3/envs/globalcal/lib/python3.10/site-packages/jax/_src/lax/lax.py:149, in broadcast_shapes(*shapes)
    148 try:
--> 149   return _broadcast_shapes_cached(*shapes)
    150 except:

File ~/miniconda3/envs/globalcal/lib/python3.10/site-packages/jax/_src/util.py:287, in cache.<locals>.wrap.<locals>.wrapper(*args, **kwargs)
    286 else:
--> 287   return cached(config.trace_context(), *args, **kwargs)

File ~/miniconda3/envs/globalcal/lib/python3.10/site-packages/jax/_src/util.py:280, in cache.<locals>.wrap.<locals>.cached(_, *args, **kwargs)
    278 @functools.lru_cache(max_size)
    279 def cached(_, *args, **kwargs):
--> 280   return f(*args, **kwargs)

File ~/miniconda3/envs/globalcal/lib/python3.10/site-packages/jax/_src/lax/lax.py:155, in _broadcast_shapes_cached(*shapes)
    153 @cache()
    154 def _broadcast_shapes_cached(*shapes: tuple[int, ...]) -> tuple[int, ...]:
--> 155   return _broadcast_shapes_uncached(*shapes)

File ~/miniconda3/envs/globalcal/lib/python3.10/site-packages/jax/_src/lax/lax.py:171, in _broadcast_shapes_uncached(*shapes)
    170 if result_shape is None:
--> 171   raise ValueError(f"Incompatible shapes for broadcasting: shapes={list(shapes)}")
    172 return result_shape

ValueError: Incompatible shapes for broadcasting: shapes=[(), (13, 100, 1), (13, 13)]

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
Cell In[2], line 53
     51 kernel = NUTS(model, max_tree_depth=10)
     52 mcmc = MCMC(kernel, num_warmup=10, num_samples=10)
---> 53 mcmc.run(jax.random.PRNGKey(42), data)
     54 mcmc.print_summary()

File ~/miniconda3/envs/globalcal/lib/python3.10/site-packages/numpyro/infer/mcmc.py:644, in MCMC.run(self, rng_key, extra_fields, init_params, *args, **kwargs)
    642 map_args = (rng_key, init_state, init_params)
    643 if self.num_chains == 1:
--> 644     states_flat, last_state = partial_map_fn(map_args)
    645     states = tree_map(lambda x: x[jnp.newaxis, ...], states_flat)
    646 else:

File ~/miniconda3/envs/globalcal/lib/python3.10/site-packages/numpyro/infer/mcmc.py:426, in MCMC._single_chain_mcmc(self, init, args, kwargs, collect_fields)
    424 # Check if _sample_fn is None, then we need to initialize the sampler.
    425 if init_state is None or (getattr(self.sampler, "_sample_fn", None) is None):
--> 426     new_init_state = self.sampler.init(
    427         rng_key,
    428         self.num_warmup,
    429         init_params,
    430         model_args=args,
    431         model_kwargs=kwargs,
    432     )
    433     init_state = new_init_state if init_state is None else init_state
    434 sample_fn, postprocess_fn = self._get_cached_fns()

File ~/miniconda3/envs/globalcal/lib/python3.10/site-packages/numpyro/infer/hmc.py:743, in HMC.init(self, rng_key, num_warmup, init_params, model_args, model_kwargs)
    738 # vectorized
    739 else:
    740     rng_key, rng_key_init_model = jnp.swapaxes(
    741         vmap(random.split)(rng_key), 0, 1
    742     )
--> 743 init_params = self._init_state(
    744     rng_key_init_model, model_args, model_kwargs, init_params
    745 )
    746 if self._potential_fn and init_params is None:
    747     raise ValueError(
    748         "Valid value of `init_params` must be provided with" " `potential_fn`."
    749     )

File ~/miniconda3/envs/globalcal/lib/python3.10/site-packages/numpyro/infer/hmc.py:687, in HMC._init_state(self, rng_key, model_args, model_kwargs, init_params)
    680 def _init_state(self, rng_key, model_args, model_kwargs, init_params):
    681     if self._model is not None:
    682         (
    683             new_init_params,
    684             potential_fn,
    685             postprocess_fn,
    686             model_trace,
--> 687         ) = initialize_model(
    688             rng_key,
    689             self._model,
    690             dynamic_args=True,
    691             init_strategy=self._init_strategy,
    692             model_args=model_args,
    693             model_kwargs=model_kwargs,
    694             forward_mode_differentiation=self._forward_mode_differentiation,
    695         )
    696         if init_params is None:
    697             init_params = new_init_params

File ~/miniconda3/envs/globalcal/lib/python3.10/site-packages/numpyro/infer/util.py:656, in initialize_model(rng_key, model, init_strategy, dynamic_args, model_args, model_kwargs, forward_mode_differentiation, validate_grad)
    646 model_kwargs = {} if model_kwargs is None else model_kwargs
    647 substituted_model = substitute(
    648     seed(model, rng_key if is_prng_key(rng_key) else rng_key[0]),
    649     substitute_fn=init_strategy,
    650 )
    651 (
    652     inv_transforms,
    653     replay_model,
    654     has_enumerate_support,
    655     model_trace,
--> 656 ) = _get_model_transforms(substituted_model, model_args, model_kwargs)
    657 # substitute param sites from model_trace to model so
    658 # we don't need to generate again parameters of `numpyro.module`
    659 model = substitute(
    660     model,
    661     data={
   (...)
    665     },
    666 )

File ~/miniconda3/envs/globalcal/lib/python3.10/site-packages/numpyro/infer/util.py:450, in _get_model_transforms(model, model_args, model_kwargs)
    448 def _get_model_transforms(model, model_args=(), model_kwargs=None):
    449     model_kwargs = {} if model_kwargs is None else model_kwargs
--> 450     model_trace = trace(model).get_trace(*model_args, **model_kwargs)
    451     inv_transforms = {}
    452     # model code may need to be replayed in the presence of deterministic sites

File ~/miniconda3/envs/globalcal/lib/python3.10/site-packages/numpyro/handlers.py:171, in trace.get_trace(self, *args, **kwargs)
    163 def get_trace(self, *args, **kwargs):
    164     """
    165     Run the wrapped callable and return the recorded trace.
    166 
   (...)
    169     :return: `OrderedDict` containing the execution trace.
    170     """
--> 171     self(*args, **kwargs)
    172     return self.trace

File ~/miniconda3/envs/globalcal/lib/python3.10/site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
    103     return self
    104 with self:
--> 105     return self.fn(*args, **kwargs)

File ~/miniconda3/envs/globalcal/lib/python3.10/site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
    103     return self
    104 with self:
--> 105     return self.fn(*args, **kwargs)

    [... skipping similar frames: Messenger.__call__ at line 105 (1 times)]

File ~/miniconda3/envs/globalcal/lib/python3.10/site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
    103     return self
    104 with self:
--> 105     return self.fn(*args, **kwargs)

Cell In[2], line 44, in model(data)
     40 with numpyro.plate("data", N):
     41     assignment = numpyro.sample("assignment", dist.Categorical(weights))
     42     numpyro.sample(
     43         "obs",
---> 44         dist.MultivariateNormal(locs[:, assignment], scale * jnp.eye(D, D)),
     45         obs=data,
     46     )

File ~/miniconda3/envs/globalcal/lib/python3.10/site-packages/numpyro/distributions/distribution.py:99, in DistributionMeta.__call__(cls, *args, **kwargs)
     97     if result is not None:
     98         return result
---> 99 return super().__call__(*args, **kwargs)

File ~/miniconda3/envs/globalcal/lib/python3.10/site-packages/numpyro/distributions/continuous.py:1448, in MultivariateNormal.__init__(self, loc, covariance_matrix, precision_matrix, scale_tril, validate_args)
   1446 loc = loc[..., jnp.newaxis]
   1447 if covariance_matrix is not None:
-> 1448     loc, self.covariance_matrix = promote_shapes(loc, covariance_matrix)
   1449     self.scale_tril = jnp.linalg.cholesky(self.covariance_matrix)
   1450 elif precision_matrix is not None:

File ~/miniconda3/envs/globalcal/lib/python3.10/site-packages/numpyro/distributions/util.py:308, in promote_shapes(shape, *args)
    306 else:
    307     shapes = [jnp.shape(arg) for arg in args]
--> 308     num_dims = len(lax.broadcast_shapes(shape, *shapes))
    309     return [
    310         _reshape(arg, (1,) * (num_dims - len(s)) + s) if len(s) < num_dims else arg
    311         for arg, s in zip(args, shapes)
    312     ]

File ~/miniconda3/envs/globalcal/lib/python3.10/site-packages/jax/_src/lax/lax.py:151, in broadcast_shapes(*shapes)
    149   return _broadcast_shapes_cached(*shapes)
    150 except:
--> 151   return _broadcast_shapes_uncached(*shapes)

File ~/miniconda3/envs/globalcal/lib/python3.10/site-packages/jax/_src/lax/lax.py:171, in _broadcast_shapes_uncached(*shapes)
    169 result_shape = _try_broadcast_shapes(shape_list)
    170 if result_shape is None:
--> 171   raise ValueError(f"Incompatible shapes for broadcasting: shapes={list(shapes)}")
    172 return result_shape

ValueError: Incompatible shapes for broadcasting: shapes=[(), (13, 100, 1), (13, 13)]

I have ran the model without the second plate construct (and no data). The sampled locs parameter has the correct shape (N_dimensions, N_components) or (D, K). But somehow the second piece doesn’t work… Not obvious to me why the shapes aren’t matching up, as my knowledge of the plate construct is rather weak.

Would greatly appreciate any pointers.

Ah it looks like there is a working version of the high-dimensional GMMs in this issue: Mixture model with discrete data in Numpyro

I missed it in my search. The code in this thread is not accessing the correct dimensions created by the .plate() construct.