I have created a problem with two biased die (each with 6 faces and different biases). They are in a bowl. Each trial, I select a dice from a bowl (Bernoulli(0.5)), throw it and record the face it lands on. The objective is to infer the two biases. I enumerate over the Bernoulli in the model (variable “choose”). However, the system tells me that I am missing the “choose” variable in the guide. This should not be happening. Here is the code. If anybody can help, I would appreciate it. Thanks.
# Put two loaded die in a bowl. Choose one die (Bernoulli) and then throw it.
# The two die will have different biases.
# Objective: retrieve the two biases.
import sys, os
import math
import numpy as np
import matplotlib.pyplot as plt
import torch
import pyro
import pyro.infer
import pyro.optim
import pyro.distributions as dist
#from IPython import embed
import torch
from pyro.infer import SVI, Trace_ELBO, config_enumerate, TraceEnum_ELBO
from pyro.optim import Adam
import torch.distributions.constraints as constraints
from torch.distributions.beta import Beta
from torch.distributions.categorical import Categorical
from torch.distributions.bernoulli import Bernoulli
#----------------------------------------------------------------------
# this is for running the notebook in our testing framework
n_svi_steps = 50000
n_die_throws = 100
n_faces = 6
batch_size = 30
# enable validation (e.g. validate parameters of distributions)
pyro.enable_validation(True)
# clear the param store in case we're in a REPL
pyro.clear_param_store()
biases0 = torch.from_numpy(np.asarray([.03, -.03, 0., 0., 0., 0.]))
biases1 = torch.from_numpy(np.asarray([.00, .00, 0., 0.02, 0., -0.02]))
samples = torch.zeros(n_die_throws)
#----------------------------------------------------------------------
def createData():
# Create biased die
alpha = .05 # bias
b = Bernoulli(0.5)
n0 = Categorical(biases0+1./6.)
n1 = Categorical(biases1+1./6.)
for i in range(n_die_throws):
choice = b.sample()
if choice:
sample = n1.sample()
else:
sample = n0.sample()
samples[i] = sample
return samples
#----------------------------------------------------------------------
#@config_enumerate(default="sequential") # more inefficient than parallel
def model(data):
# select a die
# choose is enumerated so should NOT appear in the guide.
# yet the system complains that "choose" is not in the guide. Why?
choose = pyro.sample("choose", dist.Bernoulli(0.5), infer={"enumerate": "sequential"})
prob0 = torch.ones(n_faces) * 10. # More uniform for Dirichlet
prob1 = torch.ones(n_faces) * 10. # More uniform for Dirichlet
f0 = pyro.sample("bias0", dist.Dirichlet(prob0))
f1 = pyro.sample("bias1", dist.Dirichlet(prob1))
with pyro.plate("z_minibatch", len(data), batch_size) as ind:
f = choose*f1 + (1-choose)*f0
pyro.sample("obs", dist.Categorical(f), obs=data[ind])
#----------------------------------------------------------------------
def guide(data):
prob0 = pyro.param("probs0", torch.ones(n_faces) * 10., constraint=constraints.positive)
prob1 = pyro.param("probs1", torch.ones(n_faces) * 10., constraint=constraints.positive)
pyro.sample("bias0", dist.Dirichlet(prob0))
pyro.sample("bias1", dist.Dirichlet(prob1))
#----------------------------------------------------------------------
data = createData()
# setup the optimizer
adam_params = {"lr": 0.001, "betas": (0.90, 0.999)}
optimizer = Adam(adam_params)
# setup the inference algorithm
elbo = TraceEnum_ELBO(max_plate_nesting=1)
svi = SVI(model, guide, optimizer, loss=elbo) # orig
# do gradient steps
for step in range(n_svi_steps):
loss = svi.step(data)
sys.exit()
if step % 100 == 0:
print("loss= ", step, loss)
print("Die 0 probs: ", pyro.param("probs0"))
print("Die 1 probs: ", pyro.param("probs1"))