I want to implement a structural model in which a discrete variable determines the distribution of two other discrete random variables. What is the correct way to approach this problem in Pyro? A toy example and my implementation are below.

`Question 1`: Is my implementation below correct?

`Question 2`: The posterior distribution on `s` is very brittle, i.e., it is either almost 0 or almost 1. I was expecting to see more uncertainty, especially for smaller `m` and `n` (e.g., `m` = 15, `n` = 1). For larger values of `m` and `n`, the posterior puts almost all of its mass on the correct value (direction) of `s`.

## The data

We observe `m` datasets, each comprising `n` pairs `{(x_{i,j}, y_{i,j})}`, with `i ∈ {1, …, m}` and `j ∈ {1, …, n}`, where `x_{i,j}` and `y_{i,j}` are binary random variables.

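The value of the cause indexes the parameter of the effect distribution: the effect is drawn from `Bern(pi_0)` when the cause is 0 and from `Bern(pi_1)` when the cause is 1.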
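To make the indexing concrete, here is a minimal plain-Python sketch (the helper names `effect_prob` and `effect_prob_arith` are hypothetical). It checks that the arithmetic form `pi_0**(1 - c) * pi_1**c`, which broadcasts over tensors in vectorized code, selects the same value as direct indexing.

```python
def effect_prob(cause, pi_0, pi_1):
    """P(effect = 1 | cause) by direct indexing: pi_0 if cause == 0, pi_1 if cause == 1."""
    return (pi_0, pi_1)[cause]


def effect_prob_arith(cause, pi_0, pi_1):
    """The same selection written arithmetically, so it also works elementwise on tensors."""
    return pi_0 ** (1 - cause) * pi_1 ** cause


pi_0, pi_1 = 0.2, 0.9
for c in (0, 1):
    # Both forms pick pi_0 for c == 0 and pi_1 for c == 1
    assert effect_prob(c, pi_0, pi_1) == effect_prob_arith(c, pi_0, pi_1)
```

A product such as `pi_0**c * pi_1**c` would instead evaluate to 1 whenever `c == 0`, so it does not implement this indexing.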

The data-generating statistical model is as follows:

```
// Pick whether x is the cause (s == 0) or y is the cause (s == 1)
s ~ Bern(.5)
// Pick the conditional distribution of the effect given the cause
pi_0 ~ Beta(1,1)
pi_1 ~ Beta(1,1)
// Simulate m datasets
for i in 1 ... m:
    // Pick the probability of the cause
    prob_of_cause_i ~ Beta(1,1)
    // Simulate pairs of cause and effect depending on the value of s
    for j in 1 ... n:
        if s == 0:
            // X is the cause, y is the effect
            x_{i,j} ~ Bern(prob_of_cause_i)
            // Pick the effect based on the value of the cause
            y_{i,j} ~ Bern(pi_{x_{i,j}})
        if s == 1:
            // Y is the cause, x is the effect
            y_{i,j} ~ Bern(prob_of_cause_i)
            // Pick the effect based on the value of the cause
            x_{i,j} ~ Bern(pi_{y_{i,j}})
report (x, y)
```
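As a reference for what the Pyro model should reproduce, the generative process above can be simulated directly. This is a sketch in standard-library Python; the function name `simulate` and the `seed` argument are hypothetical. It assumes the same array layout as the Pyro model below, i.e. `x[j][i]` is pair `j` of dataset `i` (the data plate at dim -2, the group plate at dim -1).

```python
import random


def simulate(m, n, seed=None):
    """Draw (s, x, y) from the generative process described above.

    x and y are n-by-m nested lists of 0/1, indexed [j][i].
    """
    rng = random.Random(seed)
    s = int(rng.random() < 0.5)                           # s ~ Bern(.5)
    pi = (rng.betavariate(1, 1), rng.betavariate(1, 1))   # pi_0, pi_1 ~ Beta(1, 1)
    x = [[0] * m for _ in range(n)]
    y = [[0] * m for _ in range(n)]
    for i in range(m):
        prob_of_cause = rng.betavariate(1, 1)             # one cause probability per dataset
        for j in range(n):
            cause = int(rng.random() < prob_of_cause)     # cause ~ Bern(prob_of_cause_i)
            effect = int(rng.random() < pi[cause])        # effect ~ Bern(pi_{cause})
            if s == 0:   # x is the cause, y is the effect
                x[j][i], y[j][i] = cause, effect
            else:        # y is the cause, x is the effect
                y[j][i], x[j][i] = cause, effect
    return s, x, y
```

For example, `s, x, y = simulate(15, 1, seed=0)` draws one dataset of the size mentioned in `Question 2`.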

## The model

Please see my model and inference routine below.

```
import logging
import torch
import numpy as np
#import pandas as pd
#import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib import rc
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO, TraceEnum_ELBO, config_enumerate, infer_discrete
smoke_test = False
assert pyro.__version__.startswith('1.8.0')
pyro.enable_validation(True)
pyro.set_rng_seed(4)
logging.basicConfig(format='%(message)s', level=logging.INFO)
# The model
@config_enumerate
def model(data=None, m=None, n=None):
    # s == 0: x causes y; s == 1: y causes x (enumerated out during inference)
    s = pyro.sample('s', dist.Bernoulli(.5))
    # P(effect = 1 | cause = 0) and P(effect = 1 | cause = 1)
    pi_0 = pyro.sample('pi_0', dist.Beta(1.0, 1.0))
    pi_1 = pyro.sample('pi_1', dist.Beta(1.0, 1.0))
    with pyro.plate("group_plate", m, dim=-1):
        # One cause probability per dataset i
        prob_of_cause = pyro.sample('prob_of_cause', dist.Beta(1.0, 1.0))
        with pyro.plate("data_plate", n, dim=-2):
            # Route the observed x/y to cause/effect according to s
            cause_obs = ((1-s)*data['x'] + s*data['y'] if data is not None else None)
            effect_obs = ((1-s)*data['y'] + s*data['x'] if data is not None else None)
            cause = pyro.sample('cause', dist.Bernoulli(prob_of_cause), obs=cause_obs)
            # Select pi_0 when cause == 0 and pi_1 when cause == 1
            prob_of_effect = (pi_0**(1 - cause))*(pi_1**cause)
            effect = pyro.sample('effect', dist.Bernoulli(prob_of_effect), obs=effect_obs)
            x = (1-s)*cause + s*effect
            y = (1-s)*effect + s*cause
    return {'x': x, 'y': y}
# Then generate some data from the model:
m, n = 15, 1
trace_model = pyro.poutine.trace(model).get_trace(None, m, n)
output_ = trace_model.nodes['_RETURN']['value']
s_orig = trace_model.nodes['s']['value']
print(f's_orig = {s_orig}')
data = {'x': output_['x'], 'y': output_['y']}
## The inference
# Now for inference, first define the guide for the continuous r.v.s:
mvn_guide = pyro.infer.autoguide.AutoDiagonalNormal(pyro.poutine.block(model, hide=['cause', 'effect', 's']))
# Set up SVI:
pyro.clear_param_store()
svi = pyro.infer.SVI(model,
                     mvn_guide,
                     pyro.optim.Adam({"lr": 0.01}),
                     TraceEnum_ELBO(max_plate_nesting=5))
losses = []
for step in range(1000 if not smoke_test else 2):
    loss = svi.step(data, m, n)
    losses.append(loss)
    if step % 100 == 0:
        logging.info("Elbo loss @{}: {}".format(step, loss))
# Sample `s` from the posterior:
ss = []
for i in range(1000):
    # Draw the continuous latents from the trained guide
    guide_trace = pyro.poutine.trace(mvn_guide).get_trace(data, m, n)
    # Condition the full model on the sampled params
    model_map = pyro.poutine.replay(model, trace=guide_trace)
    # Sample s from its conditional posterior given the sampled params and the data
    trained_model = infer_discrete(model_map, first_available_dim=-6)
    trace_model_map = pyro.poutine.trace(trained_model).get_trace(data, m, n)
    s_inferred = trace_model_map.nodes['s']['value'].detach().numpy()
    ss.append(s_inferred)
print(f's_orig = {s_orig} and P(s = {s_orig}|D) = {np.average(ss) if s_orig == 1 else 1.0 - np.average(ss)}')
```
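Because every continuous latent has a conjugate `Beta(1,1)` prior, `P(s | D)` can also be computed exactly for small problems, which gives a reference point for the SVI estimate. The sketch below is standard-library Python and assumes `x` and `y` are `n`-by-`m` nested lists of 0/1 indexed `[j][i]`, like the model's plates; the helper names `log_beta`, `log_evidence` and `posterior_s1` are hypothetical. It uses `p(D | s) = ∏_i B(1 + a_i, 1 + b_i) · ∏_{k∈{0,1}} B(1 + u_k, 1 + v_k)`, where `a_i`/`b_i` count causes equal to 1/0 in dataset `i`, and `u_k`/`v_k` count effects equal to 1/0 among all pairs (pooled over datasets) whose cause is `k`. Here `B(1, 1) = 1`, so the prior normalizers drop out.

```python
import math


def log_beta(a, b):
    """log of the Beta function B(a, b)."""
    return math.lgamma(a) + math.lgamma(b) - math.lgamma(a + b)


def log_evidence(cause, effect):
    """log p(D | s) with prob_of_cause_i, pi_0 and pi_1 integrated out.

    cause and effect are n-by-m nested lists of 0/1, indexed [j][i].
    """
    n, m = len(cause), len(cause[0])
    total = 0.0
    for i in range(m):
        ones = sum(cause[j][i] for j in range(n))
        total += log_beta(1 + ones, 1 + n - ones)        # integrate out prob_of_cause_i
    for k in (0, 1):
        effects = [effect[j][i] for j in range(n) for i in range(m) if cause[j][i] == k]
        ones = sum(effects)
        total += log_beta(1 + ones, 1 + len(effects) - ones)  # integrate out pi_k
    return total


def posterior_s1(x, y):
    """Exact P(s = 1 | D) under the Bern(.5) prior on s."""
    d = log_evidence(y, x) - log_evidence(x, y)  # log p(D|s=1) - log p(D|s=0)
    # Numerically stable logistic function of d
    if d >= 0:
        return 1.0 / (1.0 + math.exp(-d))
    return math.exp(d) / (1.0 + math.exp(d))
```

With the tensors from the generation step, something like `posterior_s1(data['x'].int().tolist(), data['y'].int().tolist())` should give the exact value. Comparing it with the sampled estimate is one way to separate modelling errors from inference artefacts: the routine above fits a single diagonal-Gaussian guide over the continuous latents that does not depend on `s`, then samples `s` conditionally on each guide draw, so its estimate need not match the exact marginal posterior.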