Problem with enum in two dice inference

I have created a problem with two biased die (each with 6 faces and different biases). They are in a bowl. Each trial, I select a dice from a bowl (Bernoulli(0.5)), throw it and record the face it lands on. The objective is to infer the two biases. I enumerate over the Bernoulli in the model (variable “choose”). However, the system tells me that I am missing the “choose” variable in the guide. This should not be happening. Here is the code. If anybody can help, I would appreciate it. Thanks.

# Put two loaded die in a bowl. Choose one die (Bernoulli) and then throw it.
# The two die will have different biases.
# Objective: retrieve the two biases.

import sys, os
import math
import numpy as np
import matplotlib.pyplot as plt
import torch
import pyro
import pyro.infer
import pyro.optim
import pyro.distributions as dist
#from IPython import embed

import torch
from pyro.infer import SVI, Trace_ELBO, config_enumerate, TraceEnum_ELBO
from pyro.optim import Adam
import torch.distributions.constraints as constraints
from torch.distributions.beta import Beta
from torch.distributions.categorical import Categorical
from torch.distributions.bernoulli import Bernoulli

# this is for running the notebook in our testing framework
n_svi_steps  = 50000
n_die_throws = 100
n_faces = 6
batch_size = 30

# enable validation (e.g. validate parameters of distributions)

# clear the param store in case we're in a REPL

biases0 = torch.from_numpy(np.asarray([.03, -.03, 0., 0., 0., 0.]))
biases1 = torch.from_numpy(np.asarray([.00,  .00, 0., 0.02, 0., -0.02]))
samples = torch.zeros(n_die_throws)

def createData():
    # Create biased die
    alpha = .05   # bias

    b = Bernoulli(0.5)
    n0 = Categorical(biases0+1./6.)
    n1 = Categorical(biases1+1./6.)

    for i in range(n_die_throws):
       choice = b.sample()
       if choice:
             sample = n1.sample()
             sample = n0.sample()

       samples[i] = sample

    return samples


#@config_enumerate(default="sequential") # more inefficient than parallel
def model(data):
    # select a die
	# choose is enumerated so should NOT appear in the guide.
	# yet the system complains that "choose" is not in the guide. Why?
    choose = pyro.sample("choose", dist.Bernoulli(0.5), infer={"enumerate": "sequential"})
    prob0 = torch.ones(n_faces) * 10.  # More uniform for Dirichlet
    prob1 = torch.ones(n_faces) * 10.  # More uniform for Dirichlet
    f0 = pyro.sample("bias0", dist.Dirichlet(prob0))
    f1 = pyro.sample("bias1", dist.Dirichlet(prob1))

    with pyro.plate("z_minibatch", len(data), batch_size) as ind:
        f = choose*f1 + (1-choose)*f0
        pyro.sample("obs", dist.Categorical(f), obs=data[ind])

def guide(data):
    prob0 = pyro.param("probs0", torch.ones(n_faces) * 10., constraint=constraints.positive)
    prob1 = pyro.param("probs1", torch.ones(n_faces) * 10., constraint=constraints.positive)
    pyro.sample("bias0", dist.Dirichlet(prob0))
    pyro.sample("bias1", dist.Dirichlet(prob1))

data = createData()

# setup the optimizer
adam_params = {"lr": 0.001, "betas": (0.90, 0.999)}
optimizer = Adam(adam_params)

# setup the inference algorithm
elbo = TraceEnum_ELBO(max_plate_nesting=1)
svi = SVI(model, guide, optimizer, loss=elbo) # orig

# do gradient steps
for step in range(n_svi_steps):
    loss = svi.step(data)

    if step % 100 == 0:
        print("loss= ", step, loss)

print("Die 0 probs: ", pyro.param("probs0"))
print("Die 1 probs: ", pyro.param("probs1"))

Anybody there?

I think you’ll need to change infer={"enumerate": "sequential"} to infer={"enumerate": "parallel"} in your model. You might find the GMM tutorial (Gaussian Mixture Model — Pyro Tutorials 1.8.4 documentation) useful for this example. Quoting from the tutorial:

When enumerating variables in the model, the variables must be enumerated in parallel and must not appear in the guide. Mathematically, guide-side enumeration simply reduces variance in a stochastic ELBO by enumerating all values, whereas model-side enumeration avoids an application of Jensen’s inequality by exactly marginalizing out a variable.

I don’t think the shapes would still align correctly, but something like the following should work (not tested for correctness), where you define a 2x6 “bias” vector and index into that directly:

    with pyro.plate("select_dice", 2):
        f = pyro.sample("bias", dist.Dirichlet(torch.ones(n_faces) * 10))
    with pyro.plate("z_minibatch", len(data), batch_size) as ind:
        choose = pyro.sample("choose", dist.Bernoulli(0.5), infer={"enumerate": "parallel"}).long()
        pyro.sample("obs", dist.Categorical(f[choose]), obs=data[ind])