 # Creating a mixture model with MixtureSameFamily

Hi again,

I’m trying to use the `MixtureSameFamily` distribution to implement a mixture of two multivariate normal distributions that share the same `loc` but have slightly different `scale_tril` values. My model is set up below, but I’m struggling to resolve the errors I get (included at the end of the post).

`MixtureSameFamily` is quite new to NumPyro, so maybe this isn’t the best way to implement this kind of mixture model. In principle, I’m trying to sample latent parameters from the distribution:

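$$p(\nu \mid Q, \mu, L_{\mathrm{in}}, L_{\mathrm{out}}) = Q\,\mathcal{N}\!\left(\nu \mid \mu,\; L_{\mathrm{in}} L_{\mathrm{in}}^{\top}\right) + (1 - Q)\,\mathcal{N}\!\left(\nu \mid \mu,\; L_{\mathrm{out}} L_{\mathrm{out}}^{\top}\right)$$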

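where $\nu$ are the latent variables, $\mu$ is the location (`loc`) vector, $L_{\mathrm{in}}$ and $L_{\mathrm{out}}$ are the lower-triangular Cholesky factors (`scale_tril`) of the two covariance matrices, and $Q$ is the mixing weight of the inlier component. All four are themselves parameters of the model below.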

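The outlier component is identical to the inlier one, except that the scale of the first dimension is inflated by a hyperparameter: the first row of the Cholesky factor is multiplied by $1 + \sigma_{\mathrm{out}}$. To set up the model I followed the TensorFlow Probability (JAX substrate) documentation for `tfp.substrates.jax.distributions.MixtureSameFamily`.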

def model(nstars=None):
    """Two-component multivariate-normal mixture over per-star latent vectors.

    Both components share the location ``mu``. The outlier component's
    Cholesky factor is the inlier one with its first row scaled by
    ``1 + sig_out``. Component 0 is the inlier (weight ``Q``) and component 1
    is the outlier (weight ``1 - Q``); each component's covariance is
    ``scale_tril @ scale_tril.T``.

    :param nstars: size of the ``"nstars"`` plate (one latent vector per star).
        NOTE(review): ``npy.plate`` presumably needs a concrete size, so the
        ``None`` default only works if callers always pass ``nstars`` -- confirm.
    """
    d = 4

    # Inlier parameters: a global scale times an LKJ correlation Cholesky factor.
    xi = npy.sample("xi", dist.LogNormal(jnp.log(1.0), 1.0))
    L_l = npy.sample("L_l", dist.LKJCholesky(d, 1.0))
    chol_in = xi[..., None] * L_l

    # Outlier parameters: same correlation structure, but the first row of
    # the Cholesky factor is scaled by (1 + sig_out). The scale vector is
    # built from ``d`` (not a hard-coded 4-element literal) so it stays in
    # sync with the dimensionality.
    sigo = npy.sample("sig_out", dist.LogNormal(jnp.log(4.0), 1.0))
    xi_out = npy.deterministic("xi_out", xi * jnp.ones(d).at[0].set(sigo + 1.0))
    chol_out = xi_out[..., None] * L_l

    # Hyperparameters.
    # NOTE(review): ``means`` is not an argument of this function, so it must
    # exist as a module-level global when the model runs -- confirm.
    mu = npy.sample("mu", dist.Normal(means, 1.0))
    Q = npy.sample("Q", dist.Beta(9, 1))

    with npy.plate("nstars", nstars):
        # Generate the latent variables (i.e. the truths) from the mixture.
        # Every distribution argument must be a JAX array. Plain Python
        # lists/tuples (e.g. ``probs=[Q, 1-Q]``) are what broke the original
        # version of this model.
        npy.sample(
            "latents",
            dist.MixtureSameFamily(
                mixing_distribution=dist.Categorical(probs=jnp.stack([Q, 1.0 - Q])),
                component_distribution=dist.MultivariateNormal(
                    loc=jnp.stack([mu, mu]),
                    scale_tril=jnp.stack([chol_in, chol_out]),
                ),
            ),
        )


Thank you so much for your help!

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/var/folders/qh/nrsz4hq94kg510wh1srt1x0ry2q37w/T/ipykernel_75022/2554853933.py in <module>
----> 1 npy.render_model(model, model_args=(obs_data.values.T, err, means, spreads))

/Users/Anaconda/anaconda3/envs/rubrum/lib/python3.8/site-packages/numpyro-0.7.2-py3.8.egg/numpyro/contrib/render.py in render_model(model, model_args, model_kwargs, filename, render_distributions, num_tries)
313     :param int num_tries: Times to trace model to detect discrete -> continuous dependency.
314     """
--> 315     relations = get_model_relations(
316         model, model_args=model_args, model_kwargs=model_kwargs, num_tries=num_tries
317     )

/Users/Anaconda/anaconda3/envs/rubrum/lib/python3.8/site-packages/numpyro-0.7.2-py3.8.egg/numpyro/contrib/render.py in get_model_relations(model, model_args, model_kwargs, num_tries)
51     model_kwargs = model_kwargs or {}
52
---> 53     trace = handlers.trace(handlers.seed(model, 0)).get_trace(
54         *model_args, **model_kwargs
55     )

/Users/Anaconda/anaconda3/envs/rubrum/lib/python3.8/site-packages/numpyro-0.7.2-py3.8.egg/numpyro/handlers.py in get_trace(self, *args, **kwargs)
163         :return: OrderedDict containing the execution trace.
164         """
--> 165         self(*args, **kwargs)
166         return self.trace
167

/Users/Anaconda/anaconda3/envs/rubrum/lib/python3.8/site-packages/numpyro-0.7.2-py3.8.egg/numpyro/primitives.py in __call__(self, *args, **kwargs)
85             return self
86         with self:
---> 87             return self.fn(*args, **kwargs)
88
89

/Users/Anaconda/anaconda3/envs/rubrum/lib/python3.8/site-packages/numpyro-0.7.2-py3.8.egg/numpyro/primitives.py in __call__(self, *args, **kwargs)
85             return self
86         with self:
---> 87             return self.fn(*args, **kwargs)
88
89

/var/folders/qh/nrsz4hq94kg510wh1srt1x0ry2q37w/T/ipykernel_75022/2004502866.py in model(y, err, means, spreads)
19         latents = npy.sample("latents", dist.MixtureSameFamily(
20                     mixture_distributions = dist.Categorical(probs = [Q, 1-Q]),
---> 21                     components_distribution = dist.MultivariateNormal(
22                         loc = (mu, mu),
23                         scale_tril = (chol_in, chol_out))))

/Users/Anaconda/anaconda3/envs/rubrum/lib/python3.8/site-packages/numpyro-0.7.2-py3.8.egg/numpyro/distributions/distribution.py in __call__(cls, *args, **kwargs)
92             if result is not None:
93                 return result
---> 94         return super().__call__(*args, **kwargs)
95
96     @property

/Users/Anaconda/anaconda3/envs/rubrum/lib/python3.8/site-packages/numpyro-0.7.2-py3.8.egg/numpyro/distributions/continuous.py in __init__(self, loc, covariance_matrix, precision_matrix, scale_tril, validate_args)
962             (loc,) = promote_shapes(loc, shape=(1,))
963         # temporary append a new axis to loc
--> 964         loc = loc[..., jnp.newaxis]
965         if covariance_matrix is not None:
966             loc, self.covariance_matrix = promote_shapes(loc, covariance_matrix)

TypeError: tuple indices must be integers or slices, not tuple


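I was able to solve the problem myself (whoops) by making sure that all the inputs to `MixtureSameFamily` were JAX NumPy arrays. For example, use `probs=jnp.array([Q, 1-Q])` rather than the plain list `[Q, 1-Q]`, and use `jnp.array(...)` rather than tuples for `loc` and `scale_tril`. The working model is below, for anybody else experiencing the same issues: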


def model(nstars=None):
    """Two-component multivariate-normal mixture over per-star latent vectors.

    Both components share the location ``mu``. The outlier component's
    Cholesky factor is the inlier one with its first row scaled by
    ``1 + sig_out``. Component 0 is the inlier (weight ``Q``) and component 1
    is the outlier (weight ``1 - Q``); each component's covariance is
    ``scale_tril @ scale_tril.T``.

    :param nstars: size of the ``"nstars"`` plate (one latent vector per star).
        NOTE(review): ``npy.plate`` presumably needs a concrete size, so the
        ``None`` default only works if callers always pass ``nstars`` -- confirm.
    """
    d = 4

    # Inlier parameters: a global scale times an LKJ correlation Cholesky factor.
    # ``xi`` is used below, but the posted version of this model never sampled
    # it. The sample site is restored here, matching the question's model.
    xi = npy.sample("xi", dist.LogNormal(jnp.log(1.0), 1.0))
    L_l = npy.sample("L_l", dist.LKJCholesky(d, 1.0))
    chol_in = xi[..., None] * L_l

    # Outlier parameters: same correlation structure, but the first row of
    # the Cholesky factor is scaled by (1 + sig_out). The scale vector is
    # built from ``d`` (not a hard-coded 4-element literal) so it stays in
    # sync with the dimensionality.
    sigo = npy.sample("sig_out", dist.LogNormal(jnp.log(4.0), 1.0))
    xi_out = npy.deterministic("xi_out", xi * jnp.ones(d).at[0].set(sigo + 1.0))
    chol_out = xi_out[..., None] * L_l

    # Hyperparameters.
    # NOTE(review): ``means`` is not an argument of this function, so it must
    # exist as a module-level global when the model runs -- confirm.
    mu = npy.sample("mu", dist.Normal(means, 1.0))
    Q = npy.sample("Q", dist.Beta(9, 1))

    with npy.plate("nstars", nstars):
        # Generate the latent variables (i.e. the truths) from the mixture.
        # All inputs are JAX arrays; plain Python lists/tuples are not
        # accepted by these distributions.
        npy.sample(
            "latents",
            dist.MixtureSameFamily(
                mixing_distribution=dist.Categorical(probs=jnp.stack([Q, 1.0 - Q])),
                component_distribution=dist.MultivariateNormal(
                    loc=jnp.stack([mu, mu]),
                    scale_tril=jnp.stack([chol_in, chol_out]),
                ),
            ),
        )

2 Likes