Short Overview
I would appreciate some help with computing the joint probability distribution over several enumerated sample sites when some of those sites are pyro.deterministic.
Below, I describe a concrete example that I would like to adapt.
Basic Situation
Suppose we have a model with categorical sample sites and a deterministic sample site:
from typing import Dict
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import config_enumerate, TraceEnum_ELBO

@config_enumerate
def model_1(obs: Dict):
    samples = dict()
    samples['a'] = pyro.sample(name='a',
                               fn=dist.Categorical(probs=torch.tensor([0.5, 0.5])),
                               obs=obs['a'] if 'a' in obs else None)
    b_given_a = torch.as_tensor([[0.2, 0.8], [0.4, 0.6]])
    samples['b'] = pyro.sample(name='b',
                               fn=dist.Categorical(probs=b_given_a[samples['a']]),
                               obs=obs['b'] if 'b' in obs else None)
    c_given_a = torch.tensor([0, 1])
    samples['c'] = pyro.deterministic(name='c', value=c_given_a[samples['a']])
    return samples
Joint distribution over pyro.sample sites:
We can compute the joint probability
p(a=0, b=0) = p(a=0) * p(b=0 | a=0) = 0.5 * 0.2 = 0.1
by the chain rule, using the marginals returned by compute_marginals while conditioning on each variable in turn, as follows:
p_joint = 1
joint_comb = {'a': torch.tensor([0]), 'b': torch.tensor([0])}
given_vars = dict()
for var in ['a', 'b']:
    margs = TraceEnum_ELBO(max_plate_nesting=1).compute_marginals(model=model_1,
                                                                 guide=lambda *args, **kwargs: None,
                                                                 obs=given_vars)
    p_joint *= margs[var].probs[joint_comb[var]]
    given_vars[var] = joint_comb[var]
# We see: the result is correct:
print(f'p(a=0,b=0): {p_joint}')
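For a model this small, the chain-rule result can be cross-checked by building the full joint tables directly from the conditional probability tables with plain torch. This is only a sanity check, assuming the tables are copied by hand from model_1; it is exactly the exhaustive enumeration that does not scale to bigger models. The names p_a, p_ab, p_ac and indicator are illustrative.

```python
import torch

# Conditional probability tables copied by hand from model_1
p_a = torch.tensor([0.5, 0.5])                      # p(a)
b_given_a = torch.tensor([[0.2, 0.8], [0.4, 0.6]])  # row i holds p(b | a=i)
c_given_a = torch.tensor([0, 1])                    # c is the deterministic function c_given_a[a]

# Full joint table p(a, b) = p(a) * p(b | a); rows index a, columns index b
p_ab = p_a[:, None] * b_given_a

# The deterministic site becomes an indicator table 1[c == c_given_a[a]]
indicator = torch.nn.functional.one_hot(c_given_a, num_classes=2).to(p_a.dtype)
p_ac = p_a[:, None] * indicator  # rows index a, columns index c

print(f'p(a=0,b=0): {p_ab[0, 0].item():.4f}')  # 0.1000
print(f'p(a=0,c=0): {p_ac[0, 0].item():.4f}')  # 0.5000
```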
Joint distribution over pyro.deterministic sites:
However, if we try to include pyro.deterministic sites, as in
p(a=0, c=0) = p(a=0) * p(c=0 | a=0) = 0.5 * 1 = 0.5,
we get a KeyError:
p_joint = 1
joint_comb = {'a': torch.tensor([0]), 'c': torch.tensor([0])}
given_vars = dict()
for var in ['a', 'c']:
    margs = TraceEnum_ELBO(max_plate_nesting=1).compute_marginals(model=model_1,
                                                                 guide=lambda *args, **kwargs: None,
                                                                 obs=given_vars)
    # This line raises a KeyError for deterministic sites: margs has no entry for 'c'.
    p_joint *= margs[var].probs[joint_comb[var]]
    given_vars[var] = joint_comb[var]
print(f'p(a=0,c=0): {p_joint}')
Quick workaround to include pyro.deterministic sites:
We can circumvent the problem by defining c as a 'dummy' categorical sample site, as in the following model, where only site c differs from model_1:
@config_enumerate
def model_2(obs: Dict):
    samples = dict()
    samples['a'] = pyro.sample(name='a',
                               fn=dist.Categorical(probs=torch.tensor([0.5, 0.5])),
                               obs=obs['a'] if 'a' in obs else None)
    b_given_a = torch.as_tensor([[0.2, 0.8], [0.4, 0.6]])
    samples['b'] = pyro.sample(name='b',
                               fn=dist.Categorical(probs=b_given_a[samples['a']]),
                               obs=obs['b'] if 'b' in obs else None)
    # We make c a sample site instead of a deterministic one, with one-hot ('deterministic') probs
    c_given_a = torch.as_tensor([[1., 0.], [0., 1.]])
    samples['c'] = pyro.sample(name='c',
                               fn=dist.Categorical(probs=c_given_a[samples['a']]),
                               obs=obs['c'] if 'c' in obs else None)
    return samples
With this change, we can compute the joint probability
p(a=0, c=0) = p(a=0) * p(c=0 | a=0) = 0.5 * 1 = 0.5
in the same manner as before, without a KeyError:
p_joint = 1
joint_comb = {'a': torch.tensor([0]), 'c': torch.tensor([0])}
given_vars = dict()
for var in ['a', 'c']:
    margs = TraceEnum_ELBO(max_plate_nesting=1).compute_marginals(model=model_2,
                                                                 guide=lambda *args, **kwargs: None,
                                                                 obs=given_vars)
    # c is now an enumerated sample site, so margs contains an entry for it
    p_joint *= margs[var].probs[joint_comb[var]]
    given_vars[var] = joint_comb[var]
# We see: the result is correct:
print(f'p(a=0,c=0): {p_joint}')
Looking for a better solution:
However, this is somewhat hacky and introduces an additional enumeration dimension for c.
I would like to avoid this, since the number of tensor dimensions available for enumeration is limited in PyTorch and my models will be fairly big.
Final questions:
Main question:

How can I include deterministic sites without declaring additional enumeration sites?
Simply hardcoding p(c) is not an option: in my real models, c depends deterministically on a combination of several earlier sample sites, so I would need to sum over all combinations of the sites c depends on (cumbersome and slow).
Additional:
Is there an easier way to compute such joint probabilities, e.g. using Pyro's built-in methods?
I would be grateful for any help. Thanks in advance.