Multidimensional Gaussian Mixture Model from data

Hi,

Pyro looks very interesting and promising, thank you.

Is it possible to infer (fit) a multidimensional Gaussian Mixture Model directly from a matrix of observations using Pyro?
(Ideally, the number of GMM components would also be estimated from the data.)

kind regards,
Valery

P.S. The site's sign-up and sign-in flow feels like it is from 2005: there are errors everywhere, and the activation email landed in Gmail's spam folder with a warning that Google can't verify that the email actually came from the stated sender.

Hi, see the Gaussian mixture model tutorial for an introduction to mixture models in Pyro, and the Dirichlet process mixture model tutorial for an example of fitting a mixture model with an unknown number of components.
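For the unknown-number-of-components case, the Dirichlet process approach is usually implemented with a truncated stick-breaking construction. A minimal sketch of that construction in plain PyTorch follows; the truncation level `T` and concentration `alpha` are illustrative assumptions, not values from this thread:

```python
import torch
import torch.nn.functional as F

def mix_weights(beta):
    """Turn T-1 stick-breaking fractions into T mixture weights that sum to 1."""
    remaining = (1 - beta).cumprod(-1)  # stick length left after each break
    return F.pad(beta, (0, 1), value=1.0) * F.pad(remaining, (1, 0), value=1.0)

T, alpha = 20, 0.1  # illustrative truncation level and concentration
torch.manual_seed(0)
beta = torch.distributions.Beta(1.0, alpha).sample((T - 1,))
weights = mix_weights(beta)
print(weights.shape, weights.sum())
```

Components whose weight ends up close to zero are effectively unused, which is how the number of active modes gets estimated from the data.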

many thanks for the hint!

“we do not parameterize the covariance matrices of the Gaussians, though this should be done when analyzing a real-world dataset for more flexibility”

Argh... the most interesting part is left out of the Dirichlet process mixture model tutorial :(

any hint on this, please?

You can just sample multiple scales along with the means. For the simpler model in the Gaussian mixture model tutorial, that means moving the scale variable inside the components plate and indexing it with the assignment in the likelihood:

import torch
import pyro
import pyro.distributions as dist

K = 2  # number of mixture components

def model(data):
    weights = pyro.sample('weights', dist.Dirichlet(0.5 * torch.ones(K)))
    # One scale and one location per component.
    with pyro.plate('components', K):
        scales = pyro.sample('scale', dist.LogNormal(0., 2.))
        locs = pyro.sample('locs', dist.Normal(0., 10.))

    with pyro.plate('data', len(data)):
        assignment = pyro.sample('assignment', dist.Categorical(weights))
        pyro.sample('obs', dist.Normal(locs[assignment], scales[assignment]), obs=data)

In the multivariate setting, you can do something similar using the LKJCorrCholesky distribution (docs, example) to define a prior distribution over covariance Cholesky factors for each mixture component.

regarding “Dirichlet process mixture model tutorial”:

This section of the tutorial doesn’t seem to be reproducible anymore:

  0%|          | 0/1500 [00:00<?, ?it/s]
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
~usr/local/lib/python3.6/dist-packages/pyro/poutine/trace_struct.py in compute_log_prob(self, site_filter)
215                     try:
--> 216                         log_p = site["fn"].log_prob(site["value"], *site["args"], **site["kwargs"])
217                     except ValueError:

~usr/local/lib/python3.6/dist-packages/torch/distributions/poisson.py in log_prob(self, value)
 60         if self._validate_args:
---> 61             self._validate_sample(value)
 62         rate, value = broadcast_all(self.rate, value)

~usr/local/lib/python3.6/dist-packages/torch/distributions/distribution.py in _validate_sample(self, value)
252         if not self.support.check(value).all():
--> 253             raise ValueError('The value argument must be within the support')
254 

ValueError: The value argument must be within the support

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
<ipython-input-9-4e79a15ea6fc> in <module>
 32 losses = []
 33 
---> 34 train(n_iter)
 35 
 36 samples = torch.arange(0, 300).type(torch.float)

<ipython-input-6-c5d9a36e9211> in train(num_iterations)
  7     pyro.clear_param_store()
  8     for j in tqdm(range(num_iterations)):
----> 9         loss = svi.step(data)
 10         losses.append(loss)
 11 

~usr/local/lib/python3.6/dist-packages/pyro/infer/svi.py in step(self, *args, **kwargs)
126         # get loss and compute gradients
127         with poutine.trace(param_only=True) as param_capture:
--> 128             loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
129 
130         params = set(site["value"].unconstrained()

~usr/local/lib/python3.6/dist-packages/pyro/infer/trace_elbo.py in loss_and_grads(self, model, guide, *args, **kwargs)
124         loss = 0.0
125         # grab a trace from the generator
--> 126         for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
127             loss_particle, surrogate_loss_particle = self._differentiable_loss_particle(model_trace, guide_trace)
128             loss += loss_particle / self.num_particles

~usr/local/lib/python3.6/dist-packages/pyro/infer/elbo.py in _get_traces(self, model, guide, args, kwargs)
168         else:
169             for i in range(self.num_particles):
--> 170                 yield self._get_trace(model, guide, args, kwargs)

~usr/local/lib/python3.6/dist-packages/pyro/infer/trace_elbo.py in _get_trace(self, model, guide, args, kwargs)
 51         """
 52         model_trace, guide_trace = get_importance_trace(
---> 53             "flat", self.max_plate_nesting, model, guide, args, kwargs)
 54         if is_validation_enabled():
 55             check_if_enumerated(guide_trace)

~usr/local/lib/python3.6/dist-packages/pyro/infer/enum.py in get_importance_trace(graph_type, max_plate_nesting, model, guide, args, kwargs, detach)
 53     model_trace = prune_subsample_sites(model_trace)
 54 
---> 55     model_trace.compute_log_prob()
 56     guide_trace.compute_score_parts()
 57     if is_validation_enabled():

~usr/local/lib/python3.6/dist-packages/pyro/poutine/trace_struct.py in compute_log_prob(self, site_filter)
219                         shapes = self.format_shapes(last_site=site["name"])
220                         raise ValueError("Error while computing log_prob at site '{}':\n{}\n{}"
--> 221                                          .format(name, exc_value, shapes)).with_traceback(traceback)
222                     site["unscaled_log_prob"] = log_p
223                     log_p = scale_and_mask(log_p, site["scale"], site["mask"])

~usr/local/lib/python3.6/dist-packages/pyro/poutine/trace_struct.py in compute_log_prob(self, site_filter)
214                 if "log_prob" not in site:
215                     try:
--> 216                         log_p = site["fn"].log_prob(site["value"], *site["args"], **site["kwargs"])
217                     except ValueError:
218                         _, exc_value, traceback = sys.exc_info()

~usr/local/lib/python3.6/dist-packages/torch/distributions/poisson.py in log_prob(self, value)
 59     def log_prob(self, value):
 60         if self._validate_args:
---> 61             self._validate_sample(value)
 62         rate, value = broadcast_all(self.rate, value)
 63         return (rate.log() * value) - rate - (value + 1).lgamma()

~usr/local/lib/python3.6/dist-packages/torch/distributions/distribution.py in _validate_sample(self, value)
251 
252         if not self.support.check(value).all():
--> 253             raise ValueError('The value argument must be within the support')
254 
255     def _get_checked_instance(self, cls, _instance=None):

ValueError: Error while computing log_prob at site 'obs':
The value argument must be within the support
Trace Shapes:
 Param Sites:
Sample Sites:
   beta dist  19 |
       value  19 |
    log_prob  19 |
 lambda dist  20 |
       value  20 |
    log_prob  20 |
      z dist 320 |
       value 320 |
    log_prob 320 |
    obs dist 320 |
       value 320 |
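The error is raised at the 'obs' site, whose likelihood is a Poisson, so at least one observed value lies outside the Poisson support (non-negative integers). The post does not show the data, so the cause is an assumption, but non-integer or negative entries in the loaded dataset would produce exactly this message. A small check in plain PyTorch, using made-up values:

```python
import torch

# Made-up observations: two valid counts, one non-integer, one negative.
values = torch.tensor([0.0, 3.0, 8.3, -1.0])

support = torch.distributions.Poisson(torch.tensor(1.0)).support
in_support = support.check(values)
print(in_support)  # marks which entries a Poisson likelihood would accept

# One possible workaround (an assumption, it may not suit every dataset):
# round to the nearest integer and clip negatives to zero.
fixed = values.round().clamp(min=0)
print(support.check(fixed).all())
```

Running the same `support.check` on the tutorial's `data` tensor before calling `svi.step` would show whether this is what is happening.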

Hi All,

Forgive me, I am new to Pyro, but I tried to implement the discussion above because I need a multivariate mixture with a full covariance matrix. The code seemed too long to post inline, so I put it in a gist here.

The model infers reasonably well, but with two major problems:

  1. A very large number of iterations is required to reach a reasonable estimate. I suspect this is due to floating-point precision (I had trouble with double()), poor initialization, or both.
  2. The LKJ prior passed to scale_tril appears to learn only the lower-triangular part of the covariance, which is what scale_tril is for, right? How do we make the upper triangle “learnable” as well?

If anyone has a chance to take a look that would be great.
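On the second point, a lower-triangular scale_tril already describes a full covariance matrix: the covariance is the factor multiplied by its own transpose, so the upper triangle is fixed by symmetry and has no separate parameters to learn. A short check in plain PyTorch, with arbitrary numbers chosen for illustration:

```python
import torch

# A lower-triangular Cholesky factor with positive diagonal (arbitrary values).
scale_tril = torch.tensor([[2.0, 0.0],
                           [0.6, 1.5]])

# The implied covariance is dense and symmetric, with nonzero off-diagonals.
cov = scale_tril @ scale_tril.T
print(cov)

mvn = torch.distributions.MultivariateNormal(torch.zeros(2), scale_tril=scale_tril)
print(torch.allclose(mvn.covariance_matrix, cov))
```

So if the fitted covariance looks wrong, the parameterization itself is unlikely to be the cause; the priors, the initialization, or the guide are more likely places to look.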
