I am generalizing the biased coin problem to a biased die. I assign probabilities to the faces of the die, draw 1000 samples, and try to infer the probability of landing on each face. The code appears to work, but convergence (via SVI) is very slow and the ELBO is very volatile. Suggestions for speeding it up would be helpful. I am using a batch size of 50. I should mention that I am relatively new to Pyro. Here is the code. Any help is appreciated.
# Based on the biased coin program in the Pyro tutorial
# ("Inference in Pyro: from stochastic functions to marginal distributions"),
# from which we solve the problem of a single loaded die. Given observations
# of N throws of a loaded die, determine the bias along with confidence
# intervals. The prior will be a Dirichlet distribution.
import sys, os
import math
import numpy as np
import matplotlib.pyplot as plt
import torch
import pyro
import pyro.infer
import pyro.optim
import pyro.distributions as dist
#from IPython import embed
from pyro.optim import Adam
from pyro.infer import SVI, Trace_ELBO
import torch
import torch.distributions.constraints as constraints
from torch.distributions.beta import Beta
from torch.distributions.categorical import Categorical
#----------------------------------------------------------------------
# Experiment settings shared by createData(), model(), guide() and the
# training loop at the bottom of the script.
n_svi_steps = 10000  # number of SVI gradient steps taken by the training loop
n_die_throws = 1000  # number of simulated throws drawn by createData()
n_faces = 6  # number of faces of the die (size of the Dirichlet/Categorical support)
# enable validation (e.g. validate parameters of distributions)
pyro.enable_validation(True)
# clear the param store in case we're in a REPL
pyro.clear_param_store()
def createData(n_throws=None, n_sides=None, alpha=.05):
    """Simulate throws of a biased die.

    The die starts out fair; probability mass ``alpha`` is then moved from
    face 0 to face 1.

    Args:
        n_throws: number of throws to simulate. Defaults to the module-level
            ``n_die_throws``.
        n_sides: number of faces of the die. Defaults to the module-level
            ``n_faces``.
        alpha: probability mass shifted from face 0 to face 1.

    Returns:
        A 1-D integer tensor of shape ``[n_throws]`` holding the index of the
        face that came up on each throw (not a one-hot encoding).

    Raises:
        ValueError: if the die has fewer than two faces, since the bias is
            applied to faces 0 and 1.
    """
    # Fall back to the script-wide settings only when the caller gave none,
    # so a plain createData() behaves as before.
    if n_throws is None:
        n_throws = n_die_throws
    if n_sides is None:
        n_sides = n_faces
    if n_sides < 2:
        raise ValueError("a biased die needs at least two faces")

    # Create biased die: uniform probabilities with alpha moved between
    # the first two faces (the total stays 1).
    face_probs = torch.ones(n_sides) * (1. / n_sides)
    face_probs[0] -= alpha
    face_probs[1] += alpha

    # sample() with sample_shape [n_throws] yields one face index per throw.
    return Categorical(face_probs).sample([n_throws])
#----------------------------------------------------------------------
def model(data):
    """Generative model for the die: Dirichlet prior, categorical throws.

    Args:
        data: tensor of observed face indices, one entry per throw, as
            returned by ``createData()``.

    Sample sites:
        ``latent_fairness``: the vector of face probabilities, drawn from a
            symmetric Dirichlet prior with all concentrations equal to 1.
        ``obs``: the observed throws, conditioned on a random mini-batch of
            ``data`` inside the ``z_minibatch`` plate.
    """
    # NOTE(review): the description accompanying this script mentions a
    # batch size of 50, but the code subsamples 10 throws per step -- confirm
    # which is intended. Smaller mini-batches give a noisier ELBO estimate.
    batch_size = 10

    # Symmetric Dirichlet prior (concentration 1 on every face).
    prior_concentration = torch.ones(n_faces)
    fairness = pyro.sample("latent_fairness", dist.Dirichlet(prior_concentration))

    # The plate draws `batch_size` random indices out of len(data) on each
    # call; only those throws are scored at the "obs" site.
    with pyro.plate("z_minibatch", len(data), subsample_size=batch_size) as ind:
        pyro.sample("obs", dist.Categorical(fairness), obs=data[ind])
#----------------------------------------------------------------------
def guide(data):
    """Variational family for ``latent_fairness``: a learnable Dirichlet.

    The Dirichlet is parameterized as ``probs * concentration``:

    * ``probs`` (simplex-constrained) is the mean of the Dirichlet, i.e. the
      current estimate of the face probabilities.
    * ``concentration`` (positive) is the total concentration, which controls
      how tightly the Dirichlet is peaked around ``probs``.

    A concentration vector forced onto the simplex always sums to 1, so the
    variational Dirichlet could never sharpen as evidence accumulates; the
    separate positive scale removes that restriction while keeping
    ``pyro.param("probs")`` interpretable as probabilities.

    Args:
        data: unused; present because SVI calls the guide with the same
            arguments as the model.
    """
    # Start from the uniform point of the simplex. (torch.ones(n_faces) is
    # not a valid simplex value, so it must not be used as the initial value
    # of a simplex-constrained parameter.)
    probs = pyro.param(
        "probs",
        torch.ones(n_faces) / n_faces,
        constraint=constraints.simplex,
    )
    # Initial total concentration of n_faces makes the initial guide
    # Dirichlet(1, ..., 1).
    concentration = pyro.param(
        "concentration",
        torch.tensor(float(n_faces)),
        constraint=constraints.positive,
    )
    pyro.sample("latent_fairness", dist.Dirichlet(probs * concentration))
#----------------------------------------------------------------------
# Simulate the observed throws of the loaded die.
data = createData()

# setup the optimizer
# NOTE(review): lr=0.0005 is a small step size for Adam; if convergence is
# slow, consider trying a larger value -- confirm against the loss curve.
adam_params = {"lr": 0.0005, "betas": (0.90, 0.999)}
optimizer = Adam(adam_params)

# Smoke test: run the model once outside of inference to check that it
# executes on the generated data.
model(data)

# setup the inference algorithm
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

# do gradient steps, reporting the loss every 100 steps
for step in range(n_svi_steps):
    loss = svi.step(data)
    if step % 100 == 0:
        print("loss= ", step, loss)

# Report the learned parameter once, after training has finished.
print("Final probabilities: ", pyro.param("probs"))