Proper use of infer_discrete and brittle posterior

I want to implement a structural model in which a discrete variable determines the distribution of two other discrete random variables. What is the correct way to approach this problem in Pyro? Below are a toy example and my implementation:

Question 1: Is my implementation below correct?

Question 2: The posterior distribution on s is very brittle, i.e., P(s = 1 | D) is either almost 0 or almost 1. I was expecting more uncertainty, especially for small m and n (e.g., m = 15, n = 1). For larger values of m and n, the posterior puts almost all its mass on the correct value (direction) of s.

The data

We observe m datasets, each comprising n pairs {(x_{i,j}, y_{i,j})}, for i ∈ 1…m and j ∈ 1…n, where x_{i,j} and y_{i,j} are binary random variables.
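The value of the cause indexes the parameter of the effect distribution: the effect is drawn from Bern(pi_0) when the cause is 0 and from Bern(pi_1) when the cause is 1.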

The data generating statistical model is as follows:

// Pick whether x is the cause (s == 0) or y is the cause (s == 1)
s ~ Bern(.5)
// Pick the conditional distribution of the effect given the cause
pi_0 ~ Beta(1,1)
pi_1 ~ Beta(1,1)

// Simulate m datasets
for i in 1 ... m:
	// Pick the probability of the cause
	prob_of_cause_i ~ Beta(1,1)

	// Simulate pairs of cause and effect depending on value of s
	for j in 1 ... n:
		if s == 0:
			// X is the cause, y is the effect
			x_{i,j} ~ Bern(prob_of_cause_i)
			// Pick the effect based on the value of the cause
			y_{i,j} ~ Bern(pi_{x_{i,j}})
		if s == 1:
			// Y is the cause, x is the effect
			y_{i,j} ~ Bern(prob_of_cause_i)
			// Pick the effect based on the value of the cause
			x_{i,j} ~ Bern(pi_{y_{i,j}})
	report (x, y)
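A minimal NumPy sketch of the generative process above, which can be handy for producing test data independently of the Pyro model. The function name `simulate` and the `(n, m)` array layout are assumptions of this sketch (the layout is chosen to match the plates in the model below: `data_plate` of size n at dim -2, `group_plate` of size m at dim -1). Following the rule that the cause's value indexes pi, the effect is drawn with pi indexed by whichever variable is the cause.

```python
import numpy as np


def simulate(m, n, rng=None):
    """Hypothetical helper: sample (s, x, y) from the generative process.

    x and y are 0/1 arrays of shape (n, m); column i holds dataset i.
    """
    rng = np.random.default_rng() if rng is None else rng
    s = rng.integers(0, 2)                       # s ~ Bern(0.5)
    pi = rng.beta(1.0, 1.0, size=2)              # pi[0] = pi_0, pi[1] = pi_1
    prob_of_cause = rng.beta(1.0, 1.0, size=m)   # one per dataset i
    # cause_{i,j} ~ Bern(prob_of_cause_i), broadcast over the n rows
    cause = (rng.random((n, m)) < prob_of_cause).astype(int)
    # effect_{i,j} ~ Bern(pi_{cause_{i,j}})
    effect = (rng.random((n, m)) < pi[cause]).astype(int)
    # s == 0: x is the cause; s == 1: y is the cause
    x, y = (cause, effect) if s == 0 else (effect, cause)
    return int(s), x, y


s, x, y = simulate(m=15, n=1, rng=np.random.default_rng(0))
print(s, x.shape, y.shape)
```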

The model

Please see my model and inference routine below.

import logging
import torch
import numpy as np
#import pandas as pd
#import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib import rc
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO, TraceEnum_ELBO, config_enumerate, infer_discrete
smoke_test = False
assert pyro.__version__.startswith('1.8.0')
logging.basicConfig(format='%(message)s', level=logging.INFO)

# The model

def model(data = None, m = None, n = None):
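    # Mark s for parallel enumeration so TraceEnum_ELBO and infer_discrete can handle it exactly
    s = pyro.sample('s', dist.Bernoulli(.5), infer={'enumerate': 'parallel'})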

    pi_0 = pyro.sample('pi_0', dist.Beta(1.0, 1.0))
    pi_1 = pyro.sample('pi_1', dist.Beta(1.0, 1.0))

    with pyro.plate("group_plate", m, dim=-1):
        prob_of_cause = pyro.sample('prob_of_cause', dist.Beta(1.0, 1.0))
        with pyro.plate("data_plate", n, dim=-2):
            cause_obs = ((1-s)*data['x'] + s*data['y'] if data is not None else None)
            effect_obs = ((1-s)*data['y'] + s*data['x'] if data is not None else None)
            cause = pyro.sample('cause', dist.Bernoulli(prob_of_cause), obs = cause_obs)
            # Select pi_0 where cause == 0 and pi_1 where cause == 1
            prob_of_effect = pi_0 * (1 - cause) + pi_1 * cause
            effect = pyro.sample('effect', dist.Bernoulli( prob_of_effect ), obs = effect_obs)
    x = (1-s)*cause + s*effect
    y = (1-s)*effect + s*cause  
    return({'x': x, 'y': y})          
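The effect probability has to pick pi_0 when the cause is 0 and pi_1 when it is 1. A quick torch check of a few ways to write that selection (the values are arbitrary, chosen only for illustration):

```python
import torch

pi_0, pi_1 = torch.tensor(0.2), torch.tensor(0.7)
cause = torch.tensor([0.0, 1.0, 1.0, 0.0])

# Linear interpolation: pi_0 where cause == 0, pi_1 where cause == 1.
# It is plain arithmetic, so it also works when cause carries extra
# (e.g. enumerated) batch dimensions.
prob_linear = pi_0 * (1 - cause) + pi_1 * cause
# Equivalent power form.
prob_power = pi_0 ** (1 - cause) * pi_1 ** cause
# Reference: explicit selection.
prob_ref = torch.where(cause == 1, pi_1, pi_0)

print(prob_linear)  # tensor([0.2000, 0.7000, 0.7000, 0.2000])

# By contrast, pi_0 ** cause * pi_1 ** cause is not a selection: it gives 1
# where cause == 0 and pi_0 * pi_1 where cause == 1.
prob_product = pi_0 ** cause * pi_1 ** cause
```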

# Then generate some data from the model:

m, n = 15, 1
trace_model = pyro.poutine.trace(model).get_trace(None, m, n)
output_ = trace_model.nodes['_RETURN']['value']
s_orig = trace_model.nodes['s']['value']
print(f's_orig = {s_orig}')
data = {'x': output_['x'], 'y': output_['y']}

## The inference

# Now for inference, first define the guide for continuous r.v.s:

mvn_guide = pyro.infer.autoguide.AutoDiagonalNormal(pyro.poutine.block(model, hide=['cause', 'effect', 's']))

# Setup SVI:

svi = pyro.infer.SVI(model,
                     mvn_guide,
                     pyro.optim.Adam({"lr": 0.01}),
                     loss=TraceEnum_ELBO(max_plate_nesting=2))

losses = []
for step in range(1000 if not smoke_test else 2):
    loss = svi.step(data, m, n)
    losses.append(loss)
    if step % 100 == 0:
        logging.info("Elbo loss @{}: {}".format(step, loss))

# Sample `s` from the posterior:

ss = []
for i in range(1000):
    guide_trace = pyro.poutine.trace(mvn_guide).get_trace(data, m, n)
    # Condition the model on the sampled continuous params
    model_map = pyro.poutine.replay(model, trace=guide_trace)
    trained_model = infer_discrete(model_map, first_available_dim=-6)
    trace_model_map = pyro.poutine.trace(trained_model).get_trace(data, m, n)
    s_inferred = trace_model_map.nodes['s']['value'].detach().numpy()
    ss.append(s_inferred)
print(f's_orig = {s_orig} and P(s = {s_orig}|D) = {np.average(ss) if s_orig == 1 else 1.0 - np.average(ss)}')
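Because every continuous variable in this model has a conjugate Beta(1, 1) prior, P(s | D) can also be computed exactly, which gives a reference value to compare the SVI estimate against. Under s = 0, p(D | s) = ∏_i B(1 + k_i, 1 + n − k_i) · ∏_{c ∈ {0,1}} B(1 + a_c, 1 + b_c), where k_i is the number of causes equal to 1 in dataset i and a_c, b_c count the effects equal to 1 and 0 among pairs whose cause is c; under s = 1 the roles of x and y swap. A sketch in NumPy, assuming the data are 0/1 arrays laid out as (n, m) and the effect probability is pi_{cause}; the helper names are hypothetical.

```python
import math

import numpy as np


def log_beta(a, b):
    """log B(a, b) via log-gamma."""
    return math.lgamma(a) + math.lgamma(b) - math.lgamma(a + b)


def log_marginal(cause, effect):
    """log p(cause, effect) with prob_of_cause_i, pi_0, pi_1 ~ Beta(1, 1) integrated out."""
    cause = np.asarray(cause, dtype=int)
    effect = np.asarray(effect, dtype=int)
    n = cause.shape[0]
    # Each dataset i has its own prob_of_cause_i ~ Beta(1, 1); note B(1, 1) = 1.
    k = cause.sum(axis=0)
    total = sum(log_beta(1 + ki, 1 + n - ki) for ki in k)
    # pi_c is shared by every pair whose cause equals c.
    for c in (0, 1):
        mask = cause == c
        ones = int(effect[mask].sum())
        zeros = int(mask.sum()) - ones
        total += log_beta(1 + ones, 1 + zeros)
    return total


def posterior_s1(x, y):
    """Exact P(s = 1 | D) under the prior s ~ Bern(0.5)."""
    log_l0 = log_marginal(cause=x, effect=y)  # s == 0: x causes y
    log_l1 = log_marginal(cause=y, effect=x)  # s == 1: y causes x
    d = log_l1 - log_l0
    # Numerically stable logistic function of the log-likelihood ratio.
    if d >= 0:
        return 1.0 / (1.0 + math.exp(-d))
    e = math.exp(d)
    return e / (1.0 + e)
```

For the data generated above, `posterior_s1(data['x'].detach().numpy(), data['y'].detach().numpy())` evaluates the exact value. Note that even with n = 1 the exact posterior need not be near 0.5: for m = 15 with every x equal to 1 and seven of the fifteen y equal to 1, it works out to P(s = 1 | D) ≈ 0.9993.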

Hi @sohrab, great example model! I think this is a case where the marginal posterior is too complex for a simple AutoDiagonalNormal guide. Instead I’d recommend hand-coding a guide with pi posteriors that depend on s and using guide-side enumeration for training. Here’s a sketch of this approach:

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import config_enumerate
from torch.distributions import constraints

# the model is not enumerated
def model(data=None, m=None, n=None):
    ...  # same as above, without any enumeration annotations

@config_enumerate  # the guide is enumerated
def guide(data=None, m=None, n=None):
    s_logits = pyro.param("s_logits", torch.zeros(()))
    s = pyro.sample("s", dist.Bernoulli(logits=s_logits))

    # How about we pack all the pi parameters into a single tensor?
    # There are eight parameters:
    # (2 values of s) x (2 sample sites) x (2 parameters to Beta)
    concentration = pyro.param("concentration", torch.ones(2, 2, 2),
                               constraint=constraints.positive)
    concentration = (concentration[0] * s[..., None, None] +
                     concentration[1] * (1 - s[..., None, None]))
    alpha, beta = concentration.unbind(-1)
    pi_0 = pyro.sample("pi_0", dist.Beta(alpha[..., 0], beta[..., 0]))
    pi_1 = pyro.sample("pi_1", dist.Beta(alpha[..., 1], beta[..., 1]))

    # Similarly we'll pack the cause parameters into a single tensor:
    # (m groups) x (2 values of s) x (2 parameters to Beta)
    cause_param = pyro.param("cause_param", torch.ones(m, 2, 2),
                             constraint=constraints.positive)
    cause_param = (cause_param[:, 0] * s[..., None] +
                   cause_param[:, 1] * (1 - s[..., None]))
    with pyro.plate("group_plate", m, dim=-1):
        prob_of_cause = pyro.sample(
            'prob_of_cause', dist.Beta(cause_param[..., 0], cause_param[..., 1]))

That guide will learn both the posterior over s and, conditioned on s, mean-field posteriors over pi_0, pi_1 and prob_of_cause. Note that we index from the right (with `...` and `None`) so as to be compatible with enumeration; these tricks are described in the tensor shapes tutorial. Let me know if you get that working :smile:

Thank you for your prompt and thorough response. Looks great! I’ll work on it and will let you know how it goes.