Early stopping


I am running the fair coin tutorial (which I understand). I can duplicate the results of the tutorial: coin fairness = 0.53 ± 0.07 (the tutorial has a std of 0.09, but who is looking closely).
I am running Pyro 0.3.

Instead of running 2000 iterations with svi.step, I implemented early stopping. If the minimum loss does not decrease within either 500 or 1000 iterations, I exit the loop, which saves a lot of time. I then print out the coin fairness at the minimum loss. When I do this, I typically get a fairness of 0.50 ± 0.07. Without the early stopping, I duplicate the results in the tutorial.

Note that the bounds are [0.43, 0.57], which overlap the interval from the tutorial, [0.46, 0.60]. But this is still somewhat disconcerting. One would expect the lower loss to be closer to the truth, especially in this simple example.

Here is my code if anybody wishes to play with it.

from __future__ import print_function
import sys
import math
import os
import torch
import torch.distributions.constraints as constraints
import pyro
from pyro.optim import Adam
from pyro.infer import SVI, Trace_ELBO
import pyro.distributions as dist

# GE added early stopping (keeps track of loss)
# min loss does not always correspond to fairness = 0.53
# I sometimes get 0.502 (still within one standard deviation of 0.53 +- 0.09, or +- 0.07)
# Should one do early stopping (to save time?)

# this is for running the notebook in our testing framework
smoke_test = ('CI' in os.environ)
n_steps = 2 if smoke_test else 10000

assert pyro.__version__.startswith('0.3.0')

# enable validation (e.g. validate parameters of distributions)
pyro.enable_validation(True)

ttensor = torch.tensor

# clear the param store in case we're in a REPL
pyro.clear_param_store()

# create some data: 60 observed heads and 40 observed tails
# (the tutorial uses 6 heads and 4 tails)
data = []
for _ in range(60):
    data.append(1.0)
for _ in range(40):
    data.append(0.0)

data = ttensor(data)

def model(data):
    # define the hyperparameters that control the beta prior
    # x ~ Beta(a, b), x in [0,1]
    alpha0 = ttensor(10.0)
    beta0 = ttensor(10.0)
    # sample f from the beta prior
    # (it looks like the arguments to dist.Beta need not be tensors if they are scalars)
    #f1 = pyro.sample("llatent_fairness", dist.Beta(alpha0, beta0))
    f = pyro.sample("latent_fairness", dist.Beta(10., 10.))
    # loop over the observed data

    # adding a plate in the model and not in the guide does not crash the program. Why? 
    # fairness of coin: 0.55
    # With the plate, the time goes from 9.3sec to 4.9 sec (50% cut in time). 
    # Likelihood: p(data | f), or Joint: p(data, f)
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Bernoulli(f), obs=data)

def guide(data):
    # register the two variational parameters with Pyro
    # - both parameters will have initial value 15.0.
    # - because we invoke constraints.positive, the optimizer
    # will take gradients on the unconstrained parameters
    # (which are related to the constrained parameters by a log)
    # The guide approximates the posterior p(f | data). ttensor required
    alpha_q = pyro.param("alpha_q", ttensor(15.0), constraint=constraints.positive)
    beta_q  = pyro.param("beta_q",  ttensor(15.0), constraint=constraints.positive)
    # sample latent_fairness from the distribution Beta(alpha_q, beta_q)
    pyro.sample("latent_fairness", dist.Beta(alpha_q, beta_q))

# setup the optimizer
adam_params = {"lr": 0.0005, "betas": (0.90, 0.999)}
optimizer = Adam(adam_params)

# setup the inference algorithm
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

# do gradient steps
print("n_steps= ", n_steps)
min_loss = 1.e8
last_step = 0
for step in range(n_steps):
    loss = svi.step(data)
    if loss < min_loss:
    #if 1:
        min_loss = loss

        #if step % 100 == 0:
        print("loss= ", loss)
        #print('.', end='')
        alpha_q = pyro.param("alpha_q").item()
        beta_q = pyro.param("beta_q").item()
        #print("alpha, beta= ", alpha_q, beta_q)
        inferred_mean_opt = alpha_q / (alpha_q + beta_q)
        factor_opt = beta_q / (alpha_q * (1.0 + alpha_q + beta_q))
        inferred_std_opt = inferred_mean_opt * math.sqrt(factor_opt)
        print("\nfairness: %.3f +- %.3f" % (inferred_mean_opt, inferred_std_opt))

        last_step = step

    if (step - last_step) > 1000:
        print("BREAK: step, last_step= ", step, last_step)
        break

print("Last min_loss update: last_step= ", last_step)

# grab the learned variational parameters
alpha_q = pyro.param("alpha_q").item()
beta_q = pyro.param("beta_q").item()

# here we use some facts about the beta distribution
# compute the inferred mean of the coin's fairness
inferred_mean = alpha_q / (alpha_q + beta_q)
# compute inferred standard deviation
factor = beta_q / (alpha_q * (1.0 + alpha_q + beta_q))
inferred_std = inferred_mean * math.sqrt(factor)

print("\nbased on the data and our prior belief, the fairness " +
      "of the coin is %.3f +- %.3f" % (inferred_mean, inferred_std))
print("\nAt min loss: fairness: %.3f +- %.3f" % (inferred_mean_opt, inferred_std_opt))
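
For comparison with "the truth": the Beta prior is conjugate to the Bernoulli likelihood, so the exact posterior for this model is available in closed form. The sketch below is not part of the script above; the helper names `beta_posterior` and `beta_mean_std` are made up for illustration. It evaluates the exact answer both for the tutorial's data (6 heads, 4 tails) and for the 60/40 data built in the script.

```python
import math


def beta_posterior(alpha0, beta0, heads, tails):
    """Conjugate update: Beta(alpha0, beta0) prior plus Bernoulli observations."""
    return alpha0 + heads, beta0 + tails


def beta_mean_std(alpha, beta):
    """Mean and standard deviation of a Beta(alpha, beta) distribution."""
    mean = alpha / (alpha + beta)
    var = alpha * beta / ((alpha + beta) ** 2 * (alpha + beta + 1.0))
    return mean, math.sqrt(var)


# Prior Beta(10, 10), as in the model above
for heads, tails in [(6, 4), (60, 40)]:
    a, b = beta_posterior(10.0, 10.0, heads, tails)
    mean, std = beta_mean_std(a, b)
    print("heads=%d tails=%d: exact posterior Beta(%g, %g), fairness %.3f +- %.3f"
          % (heads, tails, a, b, mean, std))
```

With 6 heads and 4 tails this gives 0.533 +- 0.090 (the tutorial's reference value); with 60 heads and 40 tails it gives 0.583 +- 0.045. The guide's learned alpha_q and beta_q can be compared directly against these.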

I’m not sure that a lower negative ELBO corresponds to being closer to the truth. We do know that the KL divergence is minimized when the ELBO is maximized (and is zero if the guide family contains the true posterior). But can we say the same for other, non-optimal values?

I raise this question because KL is in general not a metric, so relative comparisons may not be meaningful. Correct me if I have misunderstood something.
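
For reference, here is the standard identity relating the two quantities. The notation is assumed here rather than taken from the thread: x is the observed data, z the latent fairness, q the guide.

```latex
\log p(x)
  = \underbrace{\mathbb{E}_{q(z)}\left[\log p(x, z) - \log q(z)\right]}_{\mathrm{ELBO}(q)}
  + \mathrm{KL}\left(q(z) \,\|\, p(z \mid x)\right)
```

Since log p(x) does not depend on q, a higher exact ELBO does mean a lower KL from the guide to the posterior. It does not by itself guarantee that the guide's mean is closer to the posterior mean. Note also that svi.step returns a Monte Carlo estimate of the loss, so the smallest loss seen during training can partly reflect sampling noise rather than a better guide.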

You are correct! I forgot that KL is not a metric. So that raises the question: how does one know how many iterations to run SVI before stopping? Are there any early stopping criteria that researchers have considered? Thanks.

There are no easy answers here. One thing you can do is use some validation data distinct from your training data and then monitor the log likelihood on the validation set.
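
A minimal sketch of that idea for the coin model, in plain Python so it runs without Pyro. Everything named here is hypothetical: `heldout_log_likelihood` scores held-out flips using the guide mean alpha_q / (alpha_q + beta_q) as the per-flip predictive probability of heads, and `PatienceStopper` stops once the score has failed to improve for a given number of consecutive checks. The `trajectory` list is made-up data standing in for guide parameters read from the param store during training.

```python
import math


def heldout_log_likelihood(alpha_q, beta_q, val_data):
    """Sum of per-flip log predictive probabilities on held-out flips.

    Under a Beta(alpha_q, beta_q) guide, the predictive probability of
    heads for a single flip is the guide mean.
    """
    p_heads = alpha_q / (alpha_q + beta_q)
    return sum(math.log(p_heads) if x == 1.0 else math.log(1.0 - p_heads)
               for x in val_data)


class PatienceStopper:
    """Signal a stop after `patience` consecutive checks without improvement."""

    def __init__(self, patience):
        self.patience = patience
        self.best = -math.inf
        self.best_step = None
        self.bad_checks = 0

    def update(self, step, score):
        """Record a score (higher is better); return True when it is time to stop."""
        if score > self.best:
            self.best = score
            self.best_step = step
            self.bad_checks = 0
        else:
            self.bad_checks += 1
        return self.bad_checks >= self.patience


# Hypothetical held-out flips and a made-up sequence of guide parameters
val_data = [1.0] * 6 + [0.0] * 4
trajectory = [(15.0, 15.0), (16.0, 14.0), (17.0, 13.0),
              (18.0, 12.0), (18.0, 12.0), (18.0, 12.0), (18.0, 12.0)]

stopper = PatienceStopper(patience=2)
for step, (alpha_q, beta_q) in enumerate(trajectory):
    score = heldout_log_likelihood(alpha_q, beta_q, val_data)
    if stopper.update(step, score):
        print("stop at step %d, best step %d, best score %.3f"
              % (step, stopper.best_step, stopper.best))
        break
```

In a real run, one would presumably evaluate the score every so many calls to svi.step, reading alpha_q and beta_q with pyro.param(...).item() as in the script above. The held-out score is a deterministic function of the guide parameters, so it is less noisy than the single-sample loss, though it still depends on how representative the validation flips are.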