Building a Gaussian mixture with Relaxed Bernoulli/Categorical

Hi, I am trying to implement a 1-d Gaussian mixture of two components with the Relaxed Bernoulli (the binary case of the Concrete/Gumbel-Softmax distribution) as the variational posterior.
However, the model cannot converge to the correct result:

import torch
from torch.distributions import constraints
from tqdm import tqdm
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

p = 0.6
n_sample = 1000
mask = dist.Bernoulli(probs=p).sample((n_sample,))
loc1, loc2 = -6.0, 3.0
scale = 0.5
data = dist.MaskedMixture(mask.bool(),
                          dist.Normal(loc1, scale),
                          dist.Normal(loc2, scale)).sample()

def model(data):
    weights = pyro.param('weights', torch.tensor(0.5), constraint=constraints.unit_interval)
    locs = pyro.param('locs', torch.randn(2,))
    with pyro.plate('data', len(data)):
        assignment = pyro.sample('assignment', dist.Bernoulli(weights)).long()
        pyro.sample('obs', dist.Normal(locs[assignment], 1.0), obs=data)

T = 0.5
def guide(data):
    with pyro.plate('data', len(data)):
        alpha = pyro.param('alpha', torch.rand(len(data)), constraint=constraints.unit_interval)
        pyro.sample('assignment', dist.RelaxedBernoulliStraightThrough(torch.tensor(T), probs=alpha))
        
def train(data, svi, num_iterations):
    losses = []
    pyro.clear_param_store()
    for j in tqdm(range(num_iterations)):
        loss = svi.step(data)
        losses.append(loss)
    return losses

def initialize(seed, data, model, guide, optim):
    pyro.set_rng_seed(seed)
    pyro.clear_param_store()
    svi = SVI(model, guide, optim, Trace_ELBO(num_particles=50))
    return svi.loss(model, guide, data)

n_iter = 500
pyro.clear_param_store()
optim = Adam({'lr': 0.1, 'betas': [0.9, 0.99]})
loss, seed = min(
    [(initialize(seed, data, model, guide, optim),seed) for seed in range(100)]
)
pyro.set_rng_seed(seed)
svi = SVI(model, guide, optim, loss=Trace_ELBO(num_particles=50))
losses = train(data, svi, n_iter)

pyro.param('locs')

# Output: tensor([-0.9745, -0.4087], requires_grad=True)

Is there any way to debug this model or solve this issue? (I believe this is a local-minima problem, assuming my implementation is correct.)

have you tried a lower temperature T?

Yes, I tried that, but it did not help. (To reduce the gradient variance caused by the low temperature, I also increased num_particles to 100.)
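
For reference, this is roughly the setting I mean (a minimal sketch; vectorize_particles additionally assumes the model and guide broadcast over the extra particle dimension, which the pyro.plate above should allow):

elbo = Trace_ELBO(num_particles=100, vectorize_particles=True)
svi = SVI(model, guide, optim, loss=elbo)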

I also tried initializing the locs parameter to the ground-truth means [-6.0, 3.0]. However, both locations shrink to somewhere around -0.5 after 1000 iterations.
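
Concretely, that initialization just replaces the random init of locs in model with something like:

locs = pyro.param('locs', torch.tensor([-6.0, 3.0]))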

P.S. Interestingly, I cannot find any Gumbel-Softmax-based implementation of a mixture model on the Internet.

it may be that RelaxedBernoulliStraightThrough is buggy and/or numerically unstable. this distribution hasn’t seen much usage afaik. have you looked at the implementation?

The implementation looks good to me.

In addition, following the one-hot categorical test case (pyro/test_relaxed_straight_through.py at b31963692e176a5099027dd4837c8a4cfe673a75 · pyro-ppl/pyro · GitHub), I ran the code below:

from pyro import optim
from pyro.distributions import Bernoulli, RelaxedBernoulli, RelaxedBernoulliStraightThrough

pyro.clear_param_store()
def model():
    p = torch.tensor([0.8])
    pyro.sample('z', Bernoulli(probs=p))

def guide():
    q = pyro.param('q', torch.tensor([0.4]), constraint=constraints.unit_interval)
    temp = torch.tensor(0.05)
    pyro.sample('z', RelaxedBernoulliStraightThrough(temperature=temp, probs=q))

adam = optim.Adam({"lr": 0.1, "betas": (0.95, 0.999)})
svi = SVI(model, guide, adam, loss=Trace_ELBO(num_particles=100, vectorize_particles=True))

losses = []
for k in range(6000):
    loss = svi.step()
    losses.append(loss)
    
print(pyro.param('q'))

# Output: tensor([0.4520], grad_fn=<ClampBackward>)

Clearly, this “test case” failed.

pyro.clear_param_store()
def model(T):
    p = torch.tensor([0.8])
    temp = torch.tensor(T)
    pyro.sample('z', RelaxedBernoulli(temperature=temp, probs=p))

def guide(T):
    q = pyro.param('q', torch.tensor([0.4]), constraint=constraints.unit_interval)
    temp = torch.tensor(T)
    pyro.sample('z', RelaxedBernoulli(temperature=temp, probs=q))

adam = optim.Adam({"lr": 0.001, "betas": (0.95, 0.999)})
svi = SVI(model, guide, adam, loss=Trace_ELBO(num_particles=100, vectorize_particles=True))

losses = []
T = 1.0
for k in range(6000):
    loss = svi.step(T)
    T = max(0.5, T * (0.999 ** k))
    losses.append(loss)
    
print(pyro.param('q'))

# Output: tensor([0.8006], grad_fn=<ClampBackward>)

The model with RelaxedBernoulli as both prior and posterior works fine.

@xidulu ah yeah that’s interesting and perhaps makes sense. i don’t recall what the original references like this one do, i.e. whether they make the replacement only on the guide side or also on the model side (using pyro language). do you know?

Aha, that’s a tricky question:

The “C.2 WHAT YOU MIGHT RELAX AND WHY” section (page 15) of the Concrete paper (https://arxiv.org/pdf/1611.00712.pdf) actually discusses the different choices of relaxed vs. discrete model/prior. Their final choice is to use the relaxed Bernoulli/Categorical on both the model side and the guide side. (Meanwhile, they use a trick to obtain a stable evaluation of the KL term in the ELBO: the relaxed densities are evaluated in log/logit space rather than on the relaxed samples themselves.)
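
For reference, a minimal sketch of that log-space trick as I understand it (my own function names, a single shared temperature for q and p; not the paper's code): sample and score the relaxed variable through y = logit(z), so the densities stay finite near 0 and 1.

import torch
import torch.nn.functional as F

def binconcrete_log_density_y(y, log_alpha, temperature):
    # log-density of y = logit(z) for z ~ BinConcrete(alpha, temperature);
    # temperature is a positive tensor (Appendix C of the Concrete paper)
    return (torch.log(temperature) + log_alpha - temperature * y
            - 2.0 * F.softplus(log_alpha - temperature * y))

def kl_term_sample(q_log_alpha, p_log_alpha, temperature):
    # single-sample Monte Carlo estimate of KL(q || p); the logit-transform
    # Jacobians cancel, so scoring in y-space gives the same KL as in z-space
    u = torch.rand_like(q_log_alpha)
    logistic_noise = torch.log(u) - torch.log1p(-u)
    y = (q_log_alpha + logistic_noise) / temperature  # reparameterized sample from q
    log_q = binconcrete_log_density_y(y, q_log_alpha, temperature)
    log_p = binconcrete_log_density_y(y, p_log_alpha, temperature)
    return log_q - log_p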

In the Gumbel-Softmax paper, they use an un-relaxed prior and a relaxed posterior.

This does not seem to be a big deal when training a VAE with a discrete latent space; the network will converge anyway (also, a VAE does not have a ground truth to recover). However, when it comes to non-amortized SVI with Pyro, it seems the choice should be made carefully.

this sort of works. i think this kind of model and inference approach can lead to a nasty optimization problem with lots of bad local optima, although maybe there's something else going on as well.

import pyro
import pyro.distributions as dist
import torch
from pyro.optim import Adam
from pyro.infer import SVI, Trace_ELBO
from torch.distributions import constraints
import numpy as np


n_sample = 1000
mask = dist.Bernoulli(probs=0.6).sample((n_sample,))
loc1, loc2 = -2.0, 1.0
scale = 0.75
data = dist.MaskedMixture(mask.bool(),
                         dist.Normal(loc1, scale),
                         dist.Normal(loc2, scale)).sample()

temperature = torch.tensor(0.1)

def model(data):
    weight = pyro.param('weight', torch.tensor(0.5), constraint=constraints.unit_interval)
    locs = pyro.param('locs')  # initialized in train() below
    scale = pyro.param('scale', torch.tensor(1.0), constraint=constraints.positive)
    with pyro.plate('data', len(data)):
        assignment = pyro.sample('assignment', dist.RelaxedBernoulliStraightThrough(temperature, probs=weight))
        pyro.sample('obs', dist.Normal(locs[assignment.long()], scale), obs=data)

def guide(data):
    with pyro.plate('data', len(data)):
        alpha = pyro.param('alpha')  # per-datapoint assignment probs, initialized in train() below
        pyro.sample('assignment', dist.RelaxedBernoulliStraightThrough(temperature, probs=alpha))

def train(data, svi, num_iterations):
    losses = []

    locs = pyro.param('locs', torch.randn(2))
    # custom init strategy: assign each point to its nearest initial loc,
    # giving an initial assignment probability of 0.1 or 0.9
    alpha_init = 0.1 + 0.8 * torch.argmin((data - locs.unsqueeze(-1)).abs(), dim=0)
    pyro.param('alpha', alpha_init, constraint=constraints.unit_interval)

    for j in range(num_iterations):
        loss = svi.step(data)
        if j % 100 == 0 or j == num_iterations - 1:
            running_loss = np.mean(losses[-50:]) if j > 0 else 0.0
            s = "[iter %d] loss:  %.4f %.4f  weight: %.3f scale: %.3f locs: %.3f %.3f"
            loc1, loc2 = pyro.param('locs').data.cpu().numpy()
            print(s % (j, loss, running_loss, pyro.param('weight').item(), pyro.param('scale').item(), loc1, loc2))
        losses.append(loss)
    return losses

n_iter = 1500
pyro.clear_param_store()
optim = Adam({'lr': 0.01, 'betas': [0.90, 0.99]})
pyro.set_rng_seed(1)
svi = SVI(model, guide, optim, loss=Trace_ELBO(num_particles=8))
losses = train(data, svi, n_iter)

print("data[:4]", data[:8].data.cpu().numpy())
print("alpha[:4]", pyro.param('alpha').data.cpu().numpy()[:8])