Unable to fit arbitrary mixture model

Hello. I am trying to fit some data with a mixture model of Exponential and Normal distributions. For simplicity, the code below uses only 2 Normal distributions, as that is the simplest way of showing how it fails.

In my example, the two Normal distributions have the same weight; one has parameters (30, 5), the other (60, 5). The data I'm fitting is a list of 30s and 60s, each number repeated 8 times. A properly fitted mixture should therefore keep both distributions at equal weight, with unchanged means and smaller deviations.

The results that I’m getting instead are:

  • one distribution gets a large weight and dominates the results; the parameter var_dir ends up as [0.50, 4.99]
  • the means of both distributions migrate towards the overall mean (35 and 47, respectively), while their scales increase (17 and 18, respectively)

I've checked the tutorial [1] and other forum posts [2] on Gaussian Mixture Models (which fit properly). In those models, the data are sampled from a Normal distribution whose location is indexed out of a tensor of per-component parameters. I can't use the same approach here, as I'm interested in mixing different distribution families.
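
For reference, the likelihood in those posts looks roughly like this (my paraphrase and simplification of the tutorial, reusing the imports and data below), which only works because every component is a Normal:

# Paraphrase of the GMM-tutorial approach: all component means live in one
# tensor `locs`, so indexing by the per-datapoint assignment works, but
# every component must belong to the same family.
def gmm_model(data=data):
  weights = pyro.sample('weights', Dirichlet(tt([1., 1.])))
  locs = pyro.sample('locs', Normal(tt([30., 60.]), tt(10.)).to_event(1))
  scale = pyro.sample('scale', HalfCauchy(tt(5.)))
  with pyro.plate('data', len(data)):
    assignment = pyro.sample('assignment', Categorical(weights))
    pyro.sample('obs', Normal(locs[assignment], scale), obs=data)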

Meanwhile, there’s little on arbitrary mixture models [3].

[1] Gaussian Mixture Model — Pyro Tutorials 1.8.4 documentation
[2] Vectorizable mixture model, reply #2 by eb8680_2 (Pyro forum)
[3] Arbitrary mixture models and discrete latent variable enumeration (Pyro forum; found this, but it doesn't help much)

Full code:

import pyro
from pyro.distributions import Normal, HalfCauchy, Chi2, Dirichlet, Categorical
from pyro.infer import SVI, Trace_ELBO, config_enumerate
import torch
from torch.distributions import constraints

pyro.enable_validation(True)
pyro.clear_param_store()
pyro.set_rng_seed(42)

tt = torch.tensor
data = tt([30, 60, 30, 60, 30, 60, 30, 60, 30, 60, 30, 60, 30, 60, 30, 60])


def model(data=data):
  loc = pyro.sample('loc', Normal(tt(30.0), tt(10.0)))
  scale = pyro.sample('scale', HalfCauchy(tt(5.)))

  loc2 = pyro.sample('loc2', Normal(tt(60.0), tt(10.0)))
  scale2 = pyro.sample('scale2', HalfCauchy(tt(5.)))

  weights = pyro.sample('weights', Dirichlet(tt([1., 1.])))
  assignment = pyro.sample('assignment', Categorical(weights))
  distribs = [Normal(loc, scale), Normal(loc2, scale2)]

  with pyro.plate('plate', len(data)):
    pyro.sample('obs', distribs[assignment], obs=data)


def guide(data=data):
  var_loc = pyro.param('var_loc', tt(30.0))
  var_loc_scale = pyro.param('var_loc_scale', tt(2.), \
      constraint=constraints.positive)
  var_scale = pyro.param('var_scale', tt(5.))
  pyro.sample('loc', Normal(var_loc, var_loc_scale))
  pyro.sample('scale', Chi2(var_scale))

  var_loc2 = pyro.param('var_loc2', tt(60.0))
  var_loc2_scale = pyro.param('var_loc2_scale', tt(2.), \
      constraint=constraints.positive)
  var_scale2 = pyro.param('var_scale2', tt(5.))
  pyro.sample('loc2', Normal(var_loc2, var_loc2_scale))
  pyro.sample('scale2', Chi2(var_scale2))

  var_dir = pyro.param('var_dir', tt([1., 1.]), \
      constraint=constraints.interval(0.001, 5))
  weights = pyro.sample('weights', Dirichlet(var_dir))
  pyro.sample('assignment', Categorical(weights))
  

optim = pyro.optim.Adam({'lr': 0.1, 'betas': [0.95, 0.99]})
elbo = Trace_ELBO(max_plate_nesting=1, strict_enumeration_warning=True)
elbo.loss(model, config_enumerate(guide, "sequential"))
svi = SVI(model, guide, optim, loss=elbo)

for i in range(2001):
  svi.step(data)

  if i % 50 == 0:
    loc = pyro.param('var_loc')
    scale = pyro.param('var_scale')
    loc2 = pyro.param('var_loc2')
    scale2 = pyro.param('var_scale2')
    dirich = pyro.param('var_dir')

    print("%d %d -- %d %d -- %s" % \
        (loc, scale, loc2, scale2, dirich.data.tolist()))

Hi @aolariu, here are a few reasons your code is not working:

  1. The assignment variable is local to each datapoint in a mixture model, so it needs to be inside the data plate.
  2. You are not doing exact inference over assignments: compare to the Gaussian mixture model tutorial, which uses TraceEnum_ELBO rather than Trace_ELBO and has no guide sample site for assignment. Trace_ELBO does not support enumeration of any kind; see our enumeration tutorial for more background.
  3. You are using a Python list of component distributions, but that’s not compatible with Pyro’s parallel enumeration. As an alternative, consider using pyro.distributions.MaskedMixture to represent heterogeneous mixtures of two components.

I would suggest forking the Gaussian mixture model example and replacing the likelihood with a MaskedMixture distribution, since it already addresses the first two problems. You can see some example usage of MaskedMixture and parallel enumeration in the mixed-effect HMM example code.
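
For example, for the Exponential+Normal mixture you mentioned, the model could look roughly like this (a rough, untested sketch; the priors are placeholders, and MaskedMixture and Exponential both come from pyro.distributions):

# Rough sketch of a heterogeneous two-component mixture with MaskedMixture:
# where the boolean mask is False the first component (Exponential) is
# used, where it is True the second (Normal).
def hetero_model(data=data):
  rate = pyro.sample('rate', HalfCauchy(tt(1.)))
  loc = pyro.sample('loc', Normal(tt(60.), tt(10.)))
  scale = pyro.sample('scale', HalfCauchy(tt(5.)))
  weights = pyro.sample('weights', Dirichlet(tt([1., 1.])))
  with pyro.plate('data', len(data)):
    assignment = pyro.sample('assignment', Categorical(weights)).bool()
    pyro.sample('obs',
        MaskedMixture(assignment, Exponential(rate), Normal(loc, scale)),
        obs=data)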

Thank you @eb8680_2 for the very helpful suggestions! I implemented them and was able to get it working. Below is the working code, for anyone who might stumble on this in the future:

import pyro
from pyro.distributions import Normal, HalfCauchy, Chi2, Dirichlet, \
    Categorical, MaskedMixture
from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate
import torch
from torch.distributions import constraints

pyro.enable_validation(True)
pyro.clear_param_store()
pyro.set_rng_seed(42)

tt = torch.tensor
data = tt([30, 60, 30, 60, 30, 60])


@config_enumerate
def model(data=data):
  loc = pyro.sample('loc', Normal(tt(30.0), tt(10.0)))
  scale = pyro.sample('scale', HalfCauchy(tt(5.)))

  loc2 = pyro.sample('loc2', Normal(tt(60.0), tt(10.0)))
  scale2 = pyro.sample('scale2', HalfCauchy(tt(5.)))

  weights = pyro.sample('weights', Dirichlet(tt([1., 1.])))

  with pyro.plate('plate', len(data)):
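    # 'assignment' is a discrete site, marked for parallel enumeration via
    # @config_enumerate; it is cast to bool so it can act as the
    # MaskedMixture mask (False selects the first component, True the second).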
    assignment = pyro.sample('assignment', Categorical(weights)).bool()
    pyro.sample('obs', MaskedMixture(assignment, Normal(loc, scale), \
        Normal(loc2, scale2)), obs=data)


@config_enumerate
def guide(data=data):
  var_loc = pyro.param('var_loc', tt(30.0))
  var_loc_scale = pyro.param('var_loc_scale', tt(2.), \
      constraint=constraints.positive)
  var_scale = pyro.param('var_scale', tt(5.))
  pyro.sample('loc', Normal(var_loc, var_loc_scale))
  pyro.sample('scale', Chi2(var_scale))

  var_loc2 = pyro.param('var_loc2', tt(60.0))
  var_loc2_scale = pyro.param('var_loc2_scale', tt(2.), \
      constraint=constraints.positive)
  var_scale2 = pyro.param('var_scale2', tt(5.))
  pyro.sample('loc2', Normal(var_loc2, var_loc2_scale))
  pyro.sample('scale2', Chi2(var_scale2))

  var_dir = pyro.param('var_dir', tt([1., 1.]), \
      constraint=constraints.interval(0.001, 5))
  weights = pyro.sample('weights', Dirichlet(var_dir))
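  # Note: no guide sample site for 'assignment' here; the discrete
  # assignment is enumerated out exactly by TraceEnum_ELBO instead of
  # being sampled variationally.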
  

optim = pyro.optim.Adam({'lr': 0.1, 'betas': [0.95, 0.99]})
elbo = TraceEnum_ELBO(max_plate_nesting=1, strict_enumeration_warning=True)
elbo.loss(model, guide)
svi = SVI(model, guide, optim, loss=elbo)

for i in range(2001):
  svi.step(data)

  if i % 50 == 0:
    loc = pyro.param('var_loc')
    scale = pyro.param('var_scale')
    loc2 = pyro.param('var_loc2')
    scale2 = pyro.param('var_scale2')
    dirich = pyro.param('var_dir')

    print("%d %d -- %d %d -- %s" % \
        (loc, scale, loc2, scale2, dirich.data.tolist()))
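
A quick sanity check on the result (my addition): the mean of a Dirichlet(var_dir) distribution is var_dir / var_dir.sum(), so the expected mixture weights can be read off the fitted parameters:

# Expected mixture weights under the fitted Dirichlet; for this data set
# they should come out close to [0.5, 0.5].
var_dir = pyro.param('var_dir')
print((var_dir / var_dir.sum()).data.tolist())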