Hi,
I was trying to expand the GMM tutorial to higher dimensions. I noticed there is a previous issue in the pyro context, but I don’t think the solution offered exists (.independent()
method no longer exists).
Here is a minimal reproducible example of my attempt so far:
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.funsor import config_enumerate, infer_discrete
from numpyro.infer import MCMC, NUTS
# Generate some random data
data = dist.Normal(0, 1).sample(jax.random.PRNGKey(42), (100, 13))
print(data.shape)
# Define NumPyro model for GMMs in high dimensions
@config_enumerate
def model(data):
# Constants
N = data.shape[0] # Number of datapoints
D = 13 # Number of input dimensions
K = 3 # Number of mixture components
# Global parameters
# TODO: make scale/covariance vary by component
weights = numpyro.sample("weights", dist.Dirichlet(0.5 * jnp.ones(K)))
scale = numpyro.sample("scale", dist.LogNormal(0.0, 2.0))
# Plate construct for mean vector of each gaussian component
with numpyro.plate("components", K):
locs = numpyro.sample(
"locs", dist.Normal(jnp.ones(shape=(D, 1)), 10.0)
) # Shape = DxK
# Plate construct to loop over data
with numpyro.plate("data", N):
assignment = numpyro.sample("assignment", dist.Categorical(weights))
numpyro.sample(
"obs",
dist.MultivariateNormal(locs[:, assignment], scale * jnp.eye(D, D)),
obs=data,
)
# Inference
kernel = NUTS(model, max_tree_depth=10)
mcmc = MCMC(kernel, num_warmup=10, num_samples=10)
mcmc.run(jax.random.PRNGKey(42), data)
mcmc.print_summary()
The error that I’m getting is:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
File ~/miniconda3/envs/globalcal/lib/python3.10/site-packages/jax/_src/lax/lax.py:149, in broadcast_shapes(*shapes)
148 try:
--> 149 return _broadcast_shapes_cached(*shapes)
150 except:
File ~/miniconda3/envs/globalcal/lib/python3.10/site-packages/jax/_src/util.py:287, in cache.<locals>.wrap.<locals>.wrapper(*args, **kwargs)
286 else:
--> 287 return cached(config.trace_context(), *args, **kwargs)
File ~/miniconda3/envs/globalcal/lib/python3.10/site-packages/jax/_src/util.py:280, in cache.<locals>.wrap.<locals>.cached(_, *args, **kwargs)
278 @functools.lru_cache(max_size)
279 def cached(_, *args, **kwargs):
--> 280 return f(*args, **kwargs)
File ~/miniconda3/envs/globalcal/lib/python3.10/site-packages/jax/_src/lax/lax.py:155, in _broadcast_shapes_cached(*shapes)
153 @cache()
154 def _broadcast_shapes_cached(*shapes: tuple[int, ...]) -> tuple[int, ...]:
--> 155 return _broadcast_shapes_uncached(*shapes)
File ~/miniconda3/envs/globalcal/lib/python3.10/site-packages/jax/_src/lax/lax.py:171, in _broadcast_shapes_uncached(*shapes)
170 if result_shape is None:
--> 171 raise ValueError(f"Incompatible shapes for broadcasting: shapes={list(shapes)}")
172 return result_shape
ValueError: Incompatible shapes for broadcasting: shapes=[(), (13, 100, 1), (13, 13)]
During handling of the above exception, another exception occurred:
ValueError Traceback (most recent call last)
Cell In[2], line 53
51 kernel = NUTS(model, max_tree_depth=10)
52 mcmc = MCMC(kernel, num_warmup=10, num_samples=10)
---> 53 mcmc.run(jax.random.PRNGKey(42), data)
54 mcmc.print_summary()
File ~/miniconda3/envs/globalcal/lib/python3.10/site-packages/numpyro/infer/mcmc.py:644, in MCMC.run(self, rng_key, extra_fields, init_params, *args, **kwargs)
642 map_args = (rng_key, init_state, init_params)
643 if self.num_chains == 1:
--> 644 states_flat, last_state = partial_map_fn(map_args)
645 states = tree_map(lambda x: x[jnp.newaxis, ...], states_flat)
646 else:
File ~/miniconda3/envs/globalcal/lib/python3.10/site-packages/numpyro/infer/mcmc.py:426, in MCMC._single_chain_mcmc(self, init, args, kwargs, collect_fields)
424 # Check if _sample_fn is None, then we need to initialize the sampler.
425 if init_state is None or (getattr(self.sampler, "_sample_fn", None) is None):
--> 426 new_init_state = self.sampler.init(
427 rng_key,
428 self.num_warmup,
429 init_params,
430 model_args=args,
431 model_kwargs=kwargs,
432 )
433 init_state = new_init_state if init_state is None else init_state
434 sample_fn, postprocess_fn = self._get_cached_fns()
File ~/miniconda3/envs/globalcal/lib/python3.10/site-packages/numpyro/infer/hmc.py:743, in HMC.init(self, rng_key, num_warmup, init_params, model_args, model_kwargs)
738 # vectorized
739 else:
740 rng_key, rng_key_init_model = jnp.swapaxes(
741 vmap(random.split)(rng_key), 0, 1
742 )
--> 743 init_params = self._init_state(
744 rng_key_init_model, model_args, model_kwargs, init_params
745 )
746 if self._potential_fn and init_params is None:
747 raise ValueError(
748 "Valid value of `init_params` must be provided with" " `potential_fn`."
749 )
File ~/miniconda3/envs/globalcal/lib/python3.10/site-packages/numpyro/infer/hmc.py:687, in HMC._init_state(self, rng_key, model_args, model_kwargs, init_params)
680 def _init_state(self, rng_key, model_args, model_kwargs, init_params):
681 if self._model is not None:
682 (
683 new_init_params,
684 potential_fn,
685 postprocess_fn,
686 model_trace,
--> 687 ) = initialize_model(
688 rng_key,
689 self._model,
690 dynamic_args=True,
691 init_strategy=self._init_strategy,
692 model_args=model_args,
693 model_kwargs=model_kwargs,
694 forward_mode_differentiation=self._forward_mode_differentiation,
695 )
696 if init_params is None:
697 init_params = new_init_params
File ~/miniconda3/envs/globalcal/lib/python3.10/site-packages/numpyro/infer/util.py:656, in initialize_model(rng_key, model, init_strategy, dynamic_args, model_args, model_kwargs, forward_mode_differentiation, validate_grad)
646 model_kwargs = {} if model_kwargs is None else model_kwargs
647 substituted_model = substitute(
648 seed(model, rng_key if is_prng_key(rng_key) else rng_key[0]),
649 substitute_fn=init_strategy,
650 )
651 (
652 inv_transforms,
653 replay_model,
654 has_enumerate_support,
655 model_trace,
--> 656 ) = _get_model_transforms(substituted_model, model_args, model_kwargs)
657 # substitute param sites from model_trace to model so
658 # we don't need to generate again parameters of `numpyro.module`
659 model = substitute(
660 model,
661 data={
(...)
665 },
666 )
File ~/miniconda3/envs/globalcal/lib/python3.10/site-packages/numpyro/infer/util.py:450, in _get_model_transforms(model, model_args, model_kwargs)
448 def _get_model_transforms(model, model_args=(), model_kwargs=None):
449 model_kwargs = {} if model_kwargs is None else model_kwargs
--> 450 model_trace = trace(model).get_trace(*model_args, **model_kwargs)
451 inv_transforms = {}
452 # model code may need to be replayed in the presence of deterministic sites
File ~/miniconda3/envs/globalcal/lib/python3.10/site-packages/numpyro/handlers.py:171, in trace.get_trace(self, *args, **kwargs)
163 def get_trace(self, *args, **kwargs):
164 """
165 Run the wrapped callable and return the recorded trace.
166
(...)
169 :return: `OrderedDict` containing the execution trace.
170 """
--> 171 self(*args, **kwargs)
172 return self.trace
File ~/miniconda3/envs/globalcal/lib/python3.10/site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
103 return self
104 with self:
--> 105 return self.fn(*args, **kwargs)
File ~/miniconda3/envs/globalcal/lib/python3.10/site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
103 return self
104 with self:
--> 105 return self.fn(*args, **kwargs)
[... skipping similar frames: Messenger.__call__ at line 105 (1 times)]
File ~/miniconda3/envs/globalcal/lib/python3.10/site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
103 return self
104 with self:
--> 105 return self.fn(*args, **kwargs)
Cell In[2], line 44, in model(data)
40 with numpyro.plate("data", N):
41 assignment = numpyro.sample("assignment", dist.Categorical(weights))
42 numpyro.sample(
43 "obs",
---> 44 dist.MultivariateNormal(locs[:, assignment], scale * jnp.eye(D, D)),
45 obs=data,
46 )
File ~/miniconda3/envs/globalcal/lib/python3.10/site-packages/numpyro/distributions/distribution.py:99, in DistributionMeta.__call__(cls, *args, **kwargs)
97 if result is not None:
98 return result
---> 99 return super().__call__(*args, **kwargs)
File ~/miniconda3/envs/globalcal/lib/python3.10/site-packages/numpyro/distributions/continuous.py:1448, in MultivariateNormal.__init__(self, loc, covariance_matrix, precision_matrix, scale_tril, validate_args)
1446 loc = loc[..., jnp.newaxis]
1447 if covariance_matrix is not None:
-> 1448 loc, self.covariance_matrix = promote_shapes(loc, covariance_matrix)
1449 self.scale_tril = jnp.linalg.cholesky(self.covariance_matrix)
1450 elif precision_matrix is not None:
File ~/miniconda3/envs/globalcal/lib/python3.10/site-packages/numpyro/distributions/util.py:308, in promote_shapes(shape, *args)
306 else:
307 shapes = [jnp.shape(arg) for arg in args]
--> 308 num_dims = len(lax.broadcast_shapes(shape, *shapes))
309 return [
310 _reshape(arg, (1,) * (num_dims - len(s)) + s) if len(s) < num_dims else arg
311 for arg, s in zip(args, shapes)
312 ]
File ~/miniconda3/envs/globalcal/lib/python3.10/site-packages/jax/_src/lax/lax.py:151, in broadcast_shapes(*shapes)
149 return _broadcast_shapes_cached(*shapes)
150 except:
--> 151 return _broadcast_shapes_uncached(*shapes)
File ~/miniconda3/envs/globalcal/lib/python3.10/site-packages/jax/_src/lax/lax.py:171, in _broadcast_shapes_uncached(*shapes)
169 result_shape = _try_broadcast_shapes(shape_list)
170 if result_shape is None:
--> 171 raise ValueError(f"Incompatible shapes for broadcasting: shapes={list(shapes)}")
172 return result_shape
ValueError: Incompatible shapes for broadcasting: shapes=[(), (13, 100, 1), (13, 13)]
I have ran the model without the second plate construct (and no data). The sampled locs
parameter has the correct shape (N_dimensions, N_components) or (D, K). But somehow the second piece doesn’t work… Not obvious to me why the shapes aren’t matching up, as my knowledge of the plate construct is rather weak.
Would greatly appreciate any pointers.