Mixture Model with TraceGraph_ELBO doesn't work

Hi, there is an example of a Gaussian mixture model in Pyro that uses TraceEnum_ELBO (the “Gaussian Mixture Model” page in the Pyro Tutorials 1.8.1 documentation). I understand the idea is to use enumeration to marginalise out the discrete variable called “assignment”.
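To make “marginalise out the assignment” concrete: for each data point, summing over the K possible assignments gives log p(x_n) = logsumexp_k [log w_k + log Normal(x_n | loc_k, scale)]. The sketch below computes this by hand in plain PyTorch for fixed, made-up values of weights, locs and scale (illustrative only, not learned), and checks it against torch's MixtureSameFamily. It shows the quantity enumeration computes exactly, not Pyro's internal enumeration machinery.

```python
import torch
from torch.distributions import Categorical, MixtureSameFamily, Normal

data = torch.tensor([0., 1., 10., 11., 12.])

# Hypothetical, illustrative values for the global variables.
weights = torch.tensor([0.4, 0.6])
locs = torch.tensor([0.5, 11.0])
scale = torch.tensor(1.0)

# log p(x_n, assignment_n = k) for every data point n and component k: shape (N, K)
log_joint = weights.log() + Normal(locs, scale).log_prob(data.unsqueeze(-1))

# Sum out the discrete assignment: log p(x_n) = logsumexp_k log p(x_n, assignment_n = k)
log_marginal = torch.logsumexp(log_joint, dim=-1)  # shape (N,)

# The same per-point quantity via torch's built-in mixture distribution
mixture = MixtureSameFamily(Categorical(probs=weights), Normal(locs, scale))
print(log_marginal)
print(mixture.log_prob(data))
```

Because this sum is exact, no gradient has to be estimated through a sampled discrete variable.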

Apart from this, I also noticed that TraceGraph_ELBO can handle discrete variables and doesn’t seem to need the enumeration strategy, so I wanted to try TraceGraph_ELBO instead. However, I can’t get the correct result. I suspect one possible reason is that I didn’t write guide() correctly.

Below are my code and result. Could you please give me some advice? Thank you!

import torch
import os
import numpy as np
from torch.distributions import constraints
import pyro
import pyro.distributions as dist
from pyro.optim import Adam
from pyro.infer import SVI, TraceGraph_ELBO

smoke_test = ('CI' in os.environ)
assert pyro.__version__.startswith('1.8.1')
from pyro.distributions.torch import Multinomial, Beta, Dirichlet, Categorical, MultivariateNormal, Uniform

K = 2  # Fixed number of components.
data = torch.tensor([0., 1., 10., 11., 12.])


def model(data):
    # Global variables.
    weights = pyro.sample('weights', dist.Dirichlet(0.5 * torch.ones(K)))
    scale = pyro.sample('scale', dist.LogNormal(0., 2.))
    with pyro.plate('components', K):
        locs = pyro.sample('locs', dist.Normal(0., 10.))

    with pyro.plate('data', len(data)):
        # Local variables.
        assignment = pyro.sample('assignment', dist.Categorical(weights))
        pyro.sample('obs', dist.Normal(locs[assignment], scale), obs=data)

def guide(data):
    weights = pyro.sample('weights', dist.Dirichlet(0.5 * torch.ones(K)))

    # scale
    scale_para_loc = pyro.param('scale_para_loc', torch.tensor(1.),constraint=constraints.positive)
    scale = pyro.sample('scale', dist.LogNormal(scale_para_loc, 2.))

    # locs
    locs_para_loc = pyro.param('locs_para_loc', torch.tensor(5.))
    locs_para_scale = pyro.param('locs_para_scale', torch.tensor(1.),constraint=constraints.positive)
    with pyro.plate('components', K):
        locs = pyro.sample('locs', dist.Normal(locs_para_loc, locs_para_scale))
        
    with pyro.plate('data', len(data)):
        # Local variables.
        assignment = pyro.sample('assignment', dist.Categorical(weights))

pyro.clear_param_store()
optim = pyro.optim.Adam({'lr': 0.1, 'betas': [0.8, 0.99]})
elbo = TraceGraph_ELBO()
svi = SVI(model, guide, optim, loss=elbo)

losses = []
for i in range(501 if not smoke_test else 2):
    loss = svi.step(data)
    losses.append(loss)
    if i % 10 == 0:
        print(f"loss (negative ELBO) at iter {i}:", loss)

from pyro.infer import Predictive
num_samples = 10
predictive = Predictive(model, guide=guide, num_samples=num_samples)
graph_samples = predictive(data)
graph_samples['assignment']

However, the ELBO doesn’t converge at all, and the inferred values of the variable “assignment” don’t make any sense. Below are 10 samples of “assignment”:
tensor(
[[0, 0, 0, 0, 0],
[0, 0, 1, 0, 0],
[1, 1, 1, 1, 1],
[1, 1, 1, 1, 1],
[1, 1, 1, 1, 1],
[0, 1, 1, 0, 1],
[0, 0, 1, 0, 1],
[1, 0, 0, 0, 1],
[1, 1, 1, 1, 1],
[1, 1, 1, 1, 1]])

The samples of “locs” also look wrong:
tensor(
[[4.0562, 5.4942],
[4.2813, 6.4910],
[9.0555, 2.7641],
[6.8727, 6.9697],
[8.4839, 0.8544],
[2.1208, 8.5360],
[6.4476, 7.4402],
[5.7618, 3.1934],
[2.0187, 4.2319],
[2.9078, 4.8222]])

Thank you in advance for your help!

didn’t look at your code in detail, but as explained in this tutorial, elbo gradient estimators can be high variance in the presence of discrete latent variables. as a consequence this choice
{'lr': 0.1, 'betas': [0.8, 0.99]}
is unlikely to be a good one. something like
{'lr': 0.003, 'betas': [0.9, 0.999]}
has a better chance of succeeding, and will likely require more than a mere 501 steps.
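A toy illustration of where that variance comes from: without enumeration, a non-reparameterizable site like “assignment” needs a score-function (REINFORCE-style) gradient estimator. The sketch below uses plain PyTorch with a made-up logits vector theta and a made-up cost f(z) (both hypothetical, not Pyro code). It compares single-sample score-function estimates of d/dtheta E_{z ~ Categorical(softmax(theta))}[f(z)] with the exact gradient. TraceGraph_ELBO additionally exploits dependency structure and optional baselines to reduce variance, which this toy omits.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

theta = torch.tensor([0.3, -0.2], requires_grad=True)  # hypothetical logits of q(z)
f = torch.tensor([1.0, 5.0])                            # hypothetical cost f(z) per category

# Exact gradient of E_q[f(z)] = sum_k softmax(theta)_k * f_k, via autograd
exact = torch.autograd.grad((torch.softmax(theta, -1) * f).sum(), theta)[0]

# Single-sample score-function estimates: f(z) * d/dtheta log q(z)
num_samples = 10000
probs = torch.softmax(theta, -1).detach()
z = torch.distributions.Categorical(probs=probs).sample((num_samples,))
# For a softmax-parameterized Categorical, d/dtheta log q(z) = one_hot(z) - probs
score = F.one_hot(z, num_classes=2).float() - probs
estimates = f[z].unsqueeze(-1) * score  # shape (num_samples, 2), one estimate per row

print("exact gradient:       ", exact)
print("mean of the estimates:", estimates.mean(0))
print("std of one estimate:  ", estimates.std(0))
```

The estimates are unbiased on average, but the spread of a single estimate is larger than the gradient itself. In this toy setup, the more likely category even yields an estimate pointing the opposite way from the true gradient. Since each SVI step sees only a few such samples, a large learning rate amplifies the noise.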

Thank you so much for your quick reply.
I’ve changed the parameters to {'lr': 0.003, 'betas': [0.9, 0.999]} and ran 10000 iterations. However, the loss still fluctuates and shows no downward trend.

My model() is exactly the same as the one in the Pyro documentation (the “Gaussian Mixture Model” page in the Pyro Tutorials 1.8.1 documentation). The only difference is that the tutorial uses an AutoGuide, whereas I defined guide() myself: the AutoGuide doesn’t allow discrete latent variables, and since I use TraceGraph_ELBO without enumeration, I need to include the discrete variable in the guide.

So I think the problem is that I wrote guide() incorrectly. Could you please help me check what’s wrong with my guide()? Thank you so much!

your guide is far too inflexible.

  • your variational distribution for weights just sets it equal to the prior.
  • you set the scale parameter of the LogNormal for scale to a large fixed number instead of a learnable parameter
  • your locs are parameterized by a single scalar loc and scale instead of something of dimension K

etc etc
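One way to see why the guide’s distribution for “assignment” should differ per data point (the first guide reuses the same weights for every point): for fixed global variables, the exact conditional posterior of each assignment is the “responsibility” r_nk, proportional to w_k * Normal(x_n | loc_k, scale). The sketch below computes it in plain PyTorch for hypothetical, illustrative values of weights, locs and scale. A per-datapoint logits parameter in the guide gives it room to represent rows like these.

```python
import torch
from torch.distributions import Normal

data = torch.tensor([0., 1., 10., 11., 12.])

# Hypothetical, illustrative values for the global variables (not learned).
weights = torch.tensor([0.4, 0.6])
locs = torch.tensor([0.5, 11.0])
scale = torch.tensor(1.0)

# log w_k + log Normal(x_n | loc_k, scale), shape (N, K)
log_joint = weights.log() + Normal(locs, scale).log_prob(data.unsqueeze(-1))
# Normalize over k: p(assignment_n = k | x_n, globals); each row sums to 1
responsibilities = torch.softmax(log_joint, dim=-1)
print(responsibilities)
print(responsibilities.argmax(-1))  # most likely component for each data point
```

Each row is a different Categorical distribution, which a single shared Categorical(weights) cannot express.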

Thank you so much for your help, and sorry to bother you again. I’ve modified the guide as follows:

  1. Changed the parameters of the variable “locs” to have dimension K, using pyro.param('locs_para_loc', lambda: torch.ones(K) / K)

  2. Made the scale parameter of “scale” a learnable parameter

  3. Rewrote the distribution of the “weights” variable (though I don’t think I got it right)

The full code for the guide() is shown below:

def guide(data):

    # scale
    scale_para_loc = pyro.param('scale_para_loc', torch.tensor(1.),constraint=constraints.positive)
    scale_para_scale = pyro.param('scale_para_scale', torch.tensor(1.),constraint=constraints.positive)
    scale = pyro.sample('scale', dist.LogNormal(scale_para_loc, scale_para_scale))

    # locs
    locs_para_loc = pyro.param('locs_para_loc', lambda: torch.ones(K) / K,constraint=constraints.positive)
    locs_para_scale = pyro.param('locs_para_scale', lambda: torch.ones(K) / K,constraint=constraints.positive)

    with pyro.plate('components', K):
        locs = pyro.sample('locs', dist.Normal(locs_para_loc, locs_para_scale))
    
    weights_para = pyro.param("weights_para", lambda: torch.ones(K) / K, constraint=constraints.simplex)
    weights = pyro.sample('weights', dist.Dirichlet(weights_para))
     
    with pyro.plate('data', len(data)):
        # Local variables.
        assignment = pyro.sample('assignment', dist.Categorical(weights))

However, the TraceGraph_ELBO loss still fluctuates, and the results make no sense.

I’m quite new to Pyro and tried my best to write the guide. Sorry for asking for your help repeatedly. If possible, could you please help me correct the guide? A correct guide would be a very good example for learning how to write a custom guide. Thank you so much!

well, something like this kind of works. note, however, that this is a terrible way to attempt to solve this problem: discrete latent variables and black box variational inference generally don’t get along very well, which is why enumeration is preferred whenever possible.

def guide(data):
    # weights: Dirichlet with a learnable concentration (constrained to the simplex)
    dirichlet_param = pyro.param("dirichlet_param", torch.ones(K) / K,
                                 constraint=constraints.simplex)
    weights = pyro.sample('weights', dist.Dirichlet(dirichlet_param))

    # scale: LogNormal with a learnable loc and a small, learnable positive scale
    scale_para_loc = pyro.param('scale_para_loc', torch.tensor(0.))
    scale_para_scale = pyro.param('scale_para_scale', torch.tensor(0.001),
                                  constraint=constraints.positive)
    scale = pyro.sample('scale', dist.LogNormal(scale_para_loc, scale_para_scale))

    # locs: one Normal per component (dimension K), initialized near the two clusters
    locs_para_loc = pyro.param('locs_para_loc', torch.tensor([0.0, 10.0]))
    locs_para_scale = pyro.param('locs_para_scale', 0.001 * torch.ones(K),
                                 constraint=constraints.positive)
    with pyro.plate('components', K):
        locs = pyro.sample('locs', dist.Normal(locs_para_loc, locs_para_scale))

    with pyro.plate('data', len(data)):
        # assignment: a separate learnable Categorical for each data point
        logits = pyro.param("logits", 0.1 * torch.randn(len(data), K))
        assignment = pyro.sample('assignment', dist.Categorical(logits=logits))