Biased die problem

I am generalizing the biased coin problem to a biased die. I assign probabilities to the die faces, draw 1000 samples, and try to infer the probability of landing on each face. The code appears to work, but convergence (via SVI) is very slow and the ELBO is quite volatile. Any suggestions for speeding it up would be helpful. I am using a batch size of 50. I should mention that I am relatively new to Pyro. Here is the code; any help is appreciated.

# based on the biased coin program in the Pyro tutorial
# "Inference in Pyro: from stochastic functions to marginal distributions",
# from which we solve the problem of a single loaded die. Given observations
# of N throws of a loaded die, determine the bias along with confidence
# intervals. The prior will be a Dirichlet distribution.


import sys, os
import math
import numpy as np
import matplotlib.pyplot as plt
import torch
import pyro
import pyro.infer
import pyro.optim
import pyro.distributions as dist
#from IPython import embed

from pyro.optim import Adam
from pyro.infer import SVI, Trace_ELBO
import torch.distributions.constraints as constraints
from torch.distributions.categorical import Categorical

#----------------------------------------------------------------------
# problem configuration
n_svi_steps  = 10000
n_die_throws = 1000
n_faces = 6

# enable validation (e.g. validate parameters of distributions)
pyro.enable_validation(True)

# clear the param store in case we're in a REPL
pyro.clear_param_store()

def createData():
    # Create biased die
    alpha = .05   # bias

    m = torch.ones(n_faces)*(1./n_faces)
    m[0] -= alpha
    m[1] += alpha
    n = Categorical(m)
    samples = n.sample([n_die_throws])
    # the returned samples are face indices (0 .. n_faces-1) with shape [n_die_throws]
    return samples

#----------------------------------------------------------------------

def model(data):
    # symmetric Dirichlet(1, ..., 1) prior, i.e. uniform over the probability simplex
    batch_size = 10
    prob = torch.ones(n_faces)
    f = pyro.sample("latent_fairness", dist.Dirichlet(prob))

    # subsample the observed throws and condition on them
    with pyro.plate("z_minibatch", len(data), batch_size) as ind:
        pyro.sample("obs", dist.Categorical(f), obs=data[ind])
#----------------------------------------------------------------------

def guide(data):
    # variational Dirichlet over the face probabilities
    prob = pyro.param("probs", torch.ones(n_faces), constraint=constraints.simplex)
    f = pyro.sample("latent_fairness", dist.Dirichlet(prob))

#----------------------------------------------------------------------
data = createData()

# setup the optimizer
adam_params = {"lr": 0.0005, "betas": (0.90, 0.999)}
optimizer = Adam(adam_params)

model(data) # works

# setup the inference algorithm
svi = SVI(model, guide, optimizer, loss=Trace_ELBO()) # orig

# do gradient steps
for step in range(n_svi_steps):
    loss = svi.step(data)

    if step % 100 == 0:
        print("loss= ", step, loss)
        print("Final probabilities: ", pyro.param("probs"))

print("Final probabilities: ", pyro.param("probs"))

@erlebach What you have looks correct. The Dirichlet distribution may have numerical stability issues, so you might want to try the ClippedAdam optimizer. Also you can try a bigger batch size.
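
A minimal sketch of those two changes, assuming the rest of the script stays the same (the clip_norm and subsample-size values below are placeholders to tune, not recommendations):

from pyro.optim import ClippedAdam

# ClippedAdam clips gradient norms, which can tame occasional large gradients
# from the Dirichlet log-density
adam_params = {"lr": 0.0005, "betas": (0.90, 0.999), "clip_norm": 10.0}
optimizer = ClippedAdam(adam_params)

# and in model(), draw a larger subsample per step, e.g.
#   with pyro.plate("z_minibatch", len(data), subsample_size=200) as ind: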

Note that you could also use increasingly clever model-specific tricks (see the sketch after this list), but these won’t be useful in general models:

  • use a Multinomial instead of Categorical
  • use a DirichletMultinomial instead of a Dirichlet + Multinomial (in which case you can use an empty guide)
  • simply hand-compute the posterior :smile: posterior_concentration = 1 + data.sum(-2)
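
A rough sketch of the last two bullets (illustrative only; the collapsed_model / empty_guide names are mine, and the data.sum(-2) expression assumes one-hot encoded throws, whereas createData() returns face indices, so bincount is used to form the counts):

# collapse the index-valued throws into per-face counts
counts = torch.bincount(data, minlength=n_faces).float()

def collapsed_model(counts):
    # Dirichlet(1, ..., 1) prior over the face probabilities, integrated out
    # analytically inside the DirichletMultinomial likelihood
    prior_concentration = torch.ones(n_faces)
    pyro.sample("obs",
                dist.DirichletMultinomial(prior_concentration,
                                          total_count=int(counts.sum())),
                obs=counts)

def empty_guide(counts):
    pass  # no latent sample sites remain, so the guide is empty

# conjugate posterior over the face probabilities, written down directly
posterior_concentration = 1. + counts
posterior_mean = posterior_concentration / posterior_concentration.sum()

With the empty guide the ELBO is just the exact marginal log-likelihood, so in this conjugate case there is nothing left for SVI to optimize; the posterior is read off directly in the last two lines.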

Thanks! I do not understand most of what you wrote but will certainly take a look.

 Gordon