# Problem with enum in two dice inference

I have created a problem with two biased die (each with 6 faces and different biases). They are in a bowl. Each trial, I select a dice from a bowl (Bernoulli(0.5)), throw it and record the face it lands on. The objective is to infer the two biases. I enumerate over the Bernoulli in the model (variable “choose”). However, the system tells me that I am missing the “choose” variable in the guide. This should not be happening. Here is the code. If anybody can help, I would appreciate it. Thanks.

``````# Put two loaded die in a bowl. Choose one die (Bernoulli) and then throw it.
# The two die will have different biases.
# Objective: retrieve the two biases.

import sys, os
import math
import numpy as np
import matplotlib.pyplot as plt
import torch
import pyro
import pyro.infer
import pyro.optim
import pyro.distributions as dist
#from IPython import embed

import torch
from pyro.infer import SVI, Trace_ELBO, config_enumerate, TraceEnum_ELBO
import torch.distributions.constraints as constraints
from torch.distributions.beta import Beta
from torch.distributions.categorical import Categorical
from torch.distributions.bernoulli import Bernoulli

#----------------------------------------------------------------------
# this is for running the notebook in our testing framework
n_svi_steps  = 50000
n_die_throws = 100
n_faces = 6
batch_size = 30

# enable validation (e.g. validate parameters of distributions)
pyro.enable_validation(True)

# clear the param store in case we're in a REPL
pyro.clear_param_store()

biases0 = torch.from_numpy(np.asarray([.03, -.03, 0., 0., 0., 0.]))
biases1 = torch.from_numpy(np.asarray([.00,  .00, 0., 0.02, 0., -0.02]))
samples = torch.zeros(n_die_throws)

#----------------------------------------------------------------------
def createData():
# Create biased die
alpha = .05   # bias

b = Bernoulli(0.5)
n0 = Categorical(biases0+1./6.)
n1 = Categorical(biases1+1./6.)

for i in range(n_die_throws):
choice = b.sample()
if choice:
sample = n1.sample()
else:
sample = n0.sample()

samples[i] = sample

return samples

#----------------------------------------------------------------------

#@config_enumerate(default="sequential") # more inefficient than parallel
def model(data):
# select a die
# choose is enumerated so should NOT appear in the guide.
# yet the system complains that "choose" is not in the guide. Why?
choose = pyro.sample("choose", dist.Bernoulli(0.5), infer={"enumerate": "sequential"})
prob0 = torch.ones(n_faces) * 10.  # More uniform for Dirichlet
prob1 = torch.ones(n_faces) * 10.  # More uniform for Dirichlet
f0 = pyro.sample("bias0", dist.Dirichlet(prob0))
f1 = pyro.sample("bias1", dist.Dirichlet(prob1))

with pyro.plate("z_minibatch", len(data), batch_size) as ind:
f = choose*f1 + (1-choose)*f0
pyro.sample("obs", dist.Categorical(f), obs=data[ind])
#----------------------------------------------------------------------

def guide(data):
prob0 = pyro.param("probs0", torch.ones(n_faces) * 10., constraint=constraints.positive)
prob1 = pyro.param("probs1", torch.ones(n_faces) * 10., constraint=constraints.positive)
pyro.sample("bias0", dist.Dirichlet(prob0))
pyro.sample("bias1", dist.Dirichlet(prob1))

#----------------------------------------------------------------------
data = createData()

# setup the optimizer
adam_params = {"lr": 0.001, "betas": (0.90, 0.999)}

# setup the inference algorithm
elbo = TraceEnum_ELBO(max_plate_nesting=1)
svi = SVI(model, guide, optimizer, loss=elbo) # orig

for step in range(n_svi_steps):
loss = svi.step(data)
sys.exit()

if step % 100 == 0:
print("loss= ", step, loss)

print("Die 0 probs: ", pyro.param("probs0"))
print("Die 1 probs: ", pyro.param("probs1"))``````

Anybody there?

I think you’ll need to change `infer={"enumerate": "sequential"}` to `infer={"enumerate": "parallel"}` in your model. You might find the GMM tutorial (Gaussian Mixture Model — Pyro Tutorials 1.8.4 documentation) useful for this example. Quoting from the tutorial:

When enumerating variables in the model, the variables must be enumerated in parallel and must not appear in the guide. Mathematically, guide-side enumeration simply reduces variance in a stochastic ELBO by enumerating all values, whereas model-side enumeration avoids an application of Jensen’s inequality by exactly marginalizing out a variable.

I don’t think the shapes would still align correctly, but something like the following should work (not tested for correctness), where you define a 2x6 “bias” vector and index into that directly:

``````    with pyro.plate("select_dice", 2):
f = pyro.sample("bias", dist.Dirichlet(torch.ones(n_faces) * 10))
with pyro.plate("z_minibatch", len(data), batch_size) as ind:
choose = pyro.sample("choose", dist.Bernoulli(0.5), infer={"enumerate": "parallel"}).long()
pyro.sample("obs", dist.Categorical(f[choose]), obs=data[ind])
``````