Hi again,
I’m trying to use the `MixtureSameFamily` distribution to implement a mixture model between two Multivariate Normal distributions with the same `loc` and slightly different `scale_tril`. My model is set up below, but I’m struggling to resolve the errors I get (included at the end of the post).
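`MixtureSameFamily` is quite new to NumPyro, so maybe this isn’t the best way to implement this kind of mixture model. In principle, I’m trying to sample latent parameters from the distribution:
$$\nu \sim Q \, \mathcal{N}\!\left(\mu,\; L_{in} L_{in}^{\top}\right) + (1 - Q) \, \mathcal{N}\!\left(\mu,\; L_{out} L_{out}^{\top}\right)$$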
where $\nu$ are the latent variables, $\mu$ are the `loc` values, $L_{in}$ and $L_{out}$ are the `scale_tril` values, and $Q$ modulates the mixture. All four are also parameterised in the model, below.
The outlier part of the mixture is identical to the inlier, but the variance of the first index is inflated by a hyperparameter. To set up the model I followed the documentation here: tfp.substrates.jax.distributions.MixtureSameFamily | TensorFlow Probability
def model(nstars=None):
    """Two-component multivariate-normal mixture over per-star latent vectors.

    Both mixture components share the location ``mu`` and the LKJ Cholesky
    factor ``L_l``. The inlier scale factor is ``xi * L_l``; the outlier
    scale factor additionally multiplies its first row by ``sig_out + 1``.
    ``Q`` is the weight of the inlier component and ``1 - Q`` the weight of
    the outlier component.

    :param nstars: size of the ``"nstars"`` plate, i.e. how many latent
        vectors are drawn from the mixture.

    NOTE(review): ``means`` is read from the enclosing scope rather than
    passed as an argument; presumably it has length ``d`` -- confirm
    against the calling notebook.
    """
    d = 4

    # Inlier parameters
    xi = npy.sample("xi", dist.LogNormal(jnp.log(1.0), 1.0))
    L_l = npy.sample("L_l", dist.LKJCholesky(d, 1.0))
    chol_in = xi[..., None] * L_l

    # Outlier parameters
    sigo = npy.sample("sig_out", dist.LogNormal(jnp.log(4), 1.0))
    xi_out = npy.deterministic("xi_out", xi * jnp.array([sigo + 1., 1., 1., 1.]))
    chol_out = xi_out[..., None] * L_l

    # Hyperparameters
    mu = npy.sample("mu", dist.Normal(means, 1.))
    Q = npy.sample("Q", dist.Beta(9, 1))

    with npy.plate("nstars", nstars):
        # Distribution parameters must be arrays, not Python lists/tuples:
        # the distribution constructors index and broadcast them directly.
        # The mixture axis is the last axis of ``probs`` and the rightmost
        # batch axis of the component distribution.
        mixing = dist.Categorical(probs=jnp.stack([Q, 1 - Q], axis=-1))
        components = dist.MultivariateNormal(
            loc=jnp.stack([mu, mu], axis=-2),
            scale_tril=jnp.stack([chol_in, chol_out], axis=-3),
        )

        # Generate the latent variables (ie. the truths) from the mixture
        latents = npy.sample(
            "latents",
            dist.MixtureSameFamily(
                mixing_distribution=mixing,
                component_distribution=components,
            ),
        )
Thank you so much for your help!
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
/var/folders/qh/nrsz4hq94kg510wh1srt1x0ry2q37w/T/ipykernel_75022/2554853933.py in <module>
----> 1 npy.render_model(model, model_args=(obs_data.values.T, err, means, spreads))
/Users/Anaconda/anaconda3/envs/rubrum/lib/python3.8/site-packages/numpyro-0.7.2-py3.8.egg/numpyro/contrib/render.py in render_model(model, model_args, model_kwargs, filename, render_distributions, num_tries)
313 :param int num_tries: Times to trace model to detect discrete -> continuous dependency.
314 """
--> 315 relations = get_model_relations(
316 model, model_args=model_args, model_kwargs=model_kwargs, num_tries=num_tries
317 )
/Users/Anaconda/anaconda3/envs/rubrum/lib/python3.8/site-packages/numpyro-0.7.2-py3.8.egg/numpyro/contrib/render.py in get_model_relations(model, model_args, model_kwargs, num_tries)
51 model_kwargs = model_kwargs or {}
52
---> 53 trace = handlers.trace(handlers.seed(model, 0)).get_trace(
54 *model_args, **model_kwargs
55 )
/Users/Anaconda/anaconda3/envs/rubrum/lib/python3.8/site-packages/numpyro-0.7.2-py3.8.egg/numpyro/handlers.py in get_trace(self, *args, **kwargs)
163 :return: `OrderedDict` containing the execution trace.
164 """
--> 165 self(*args, **kwargs)
166 return self.trace
167
/Users/Anaconda/anaconda3/envs/rubrum/lib/python3.8/site-packages/numpyro-0.7.2-py3.8.egg/numpyro/primitives.py in __call__(self, *args, **kwargs)
85 return self
86 with self:
---> 87 return self.fn(*args, **kwargs)
88
89
/Users/Anaconda/anaconda3/envs/rubrum/lib/python3.8/site-packages/numpyro-0.7.2-py3.8.egg/numpyro/primitives.py in __call__(self, *args, **kwargs)
85 return self
86 with self:
---> 87 return self.fn(*args, **kwargs)
88
89
/var/folders/qh/nrsz4hq94kg510wh1srt1x0ry2q37w/T/ipykernel_75022/2004502866.py in model(y, err, means, spreads)
19 latents = npy.sample("latents", dist.MixtureSameFamily(
20 mixture_distributions = dist.Categorical(probs = [Q, 1-Q]),
---> 21 components_distribution = dist.MultivariateNormal(
22 loc = (mu, mu),
23 scale_tril = (chol_in, chol_out))))
/Users/Anaconda/anaconda3/envs/rubrum/lib/python3.8/site-packages/numpyro-0.7.2-py3.8.egg/numpyro/distributions/distribution.py in __call__(cls, *args, **kwargs)
92 if result is not None:
93 return result
---> 94 return super().__call__(*args, **kwargs)
95
96 @property
/Users/Anaconda/anaconda3/envs/rubrum/lib/python3.8/site-packages/numpyro-0.7.2-py3.8.egg/numpyro/distributions/continuous.py in __init__(self, loc, covariance_matrix, precision_matrix, scale_tril, validate_args)
962 (loc,) = promote_shapes(loc, shape=(1,))
963 # temporary append a new axis to loc
--> 964 loc = loc[..., jnp.newaxis]
965 if covariance_matrix is not None:
966 loc, self.covariance_matrix = promote_shapes(loc, covariance_matrix)
TypeError: tuple indices must be integers or slices, not tuple