# How to compute joint probability distribution including deterministic sites?

Short Overview

I would appreciate some help on how to compute the joint probability distribution over several enumerated sample sites when some of the sites are pyro.deterministic sites.

Below, I describe a concrete example that I would like to adapt.

Basic Situation

Suppose we have a model with categorical sample sites and a deterministic sample site:

``````
from typing import Dict

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import config_enumerate, TraceEnum_ELBO


@config_enumerate
def model_1(obs: Dict):
    samples = dict()
    samples['a'] = pyro.sample(name='a',
                               fn=dist.Categorical(probs=torch.tensor([0.5, 0.5])),
                               obs=obs['a'] if 'a' in obs else None)

    # Row i holds p(b | a=i).
    b_given_a = torch.as_tensor([[0.2, 0.8], [0.4, 0.6]])
    samples['b'] = pyro.sample(name='b',
                               fn=dist.Categorical(probs=b_given_a[samples['a']]),
                               obs=obs['b'] if 'b' in obs else None)

    # c is a deterministic function of a.
    c_given_a = torch.tensor([0, 1])
    samples['c'] = pyro.deterministic(name='c', value=c_given_a[samples['a']])

    return samples
``````

Joint distribution over pyro.sample sites:

We can compute the joint distribution
p(a=0,b=0) = p(a=0) * p(b=0|a=0) = 0.5 * 0.2 = 0.1
by the chain rule using the marginal distributions as follows:

``````
p_joint = 1
joint_comb = {'a': torch.tensor([0]), 'b': torch.tensor([0])}
given_vars = dict()
for var in ['a', 'b']:
    # Marginals of the enumerated sites, given the values fixed so far.
    margs = TraceEnum_ELBO(max_plate_nesting=1).compute_marginals(
        model=model_1,
        guide=lambda *args, **kwargs: None,
        obs=given_vars)

    p_joint *= margs[var].probs[joint_comb[var]]
    given_vars[var] = joint_comb[var]

# We see: the result is correct:
print(f'p(a=0,b=0): {p_joint}')
``````
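As a sanity check that does not depend on Pyro, the same number can be reproduced by hand. The sketch below hardcodes the two probability tables of model_1 as plain Python lists and applies the chain rule directly; the names `p_a`, `p_b_given_a`, and `joint_ab` are only illustrative and are not part of any Pyro API.

```python
# Probability tables copied from model_1.
p_a = [0.5, 0.5]                        # p(a)
p_b_given_a = [[0.2, 0.8], [0.4, 0.6]]  # p(b | a), rows indexed by a


def joint_ab(a: int, b: int) -> float:
    """Chain rule: p(a, b) = p(a) * p(b | a)."""
    return p_a[a] * p_b_given_a[a][b]


print(f"p(a=0,b=0): {joint_ab(0, 0):.4f}")
```

Summing `joint_ab` over all four combinations should give 1, which is a quick way to confirm the tables were copied correctly.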

Joint distribution over pyro.deterministic sites:

However, if we try to include pyro.deterministic sites as in
p(a=0, c=0) = p(a=0) * p(c=0|a=0) = 0.5 * 1 = 0.5

we get a key error:

``````
p_joint = 1
joint_comb = {'a': torch.tensor([0]), 'c': torch.tensor([0])}
given_vars = dict()
for var in ['a', 'c']:
    margs = TraceEnum_ELBO(max_plate_nesting=1).compute_marginals(
        model=model_1,
        guide=lambda *args, **kwargs: None,
        obs=given_vars)

    # This line leads to a key error for deterministic sites.
    p_joint *= margs[var].probs[joint_comb[var]]

    given_vars[var] = joint_comb[var]

print(f'p(a=0,c=0): {p_joint}')
``````

Quick workaround to include pyro.deterministic sites:

We can circumvent the problem by defining c as a 'dummy categorical' site, as in the following model, where only the site c was altered compared to the former model:

``````
@config_enumerate
def model_2(obs: Dict):
    samples = dict()
    samples['a'] = pyro.sample(name='a',
                               fn=dist.Categorical(probs=torch.tensor([0.5, 0.5])),
                               obs=obs['a'] if 'a' in obs else None)

    b_given_a = torch.as_tensor([[0.2, 0.8], [0.4, 0.6]])
    samples['b'] = pyro.sample(name='b',
                               fn=dist.Categorical(probs=b_given_a[samples['a']]))

    # We make c a sample site instead of a deterministic one, with 'deterministic probs'.
    # Floats are used so that the probabilities have a floating-point dtype.
    c_given_a = torch.as_tensor([[1.0, 0.0], [0.0, 1.0]])
    samples['c'] = pyro.sample(name='c',
                               fn=dist.Categorical(probs=c_given_a[samples['a']]))

    return samples
``````

With this change, we can compute the joint probability:
p(a=0, c=0) = p(a=0) * p(c=0|a=0) = 0.5 * 1 = 0.5

in the same manner as before, without any key error:

``````
p_joint = 1
joint_comb = {'a': torch.tensor([0]), 'c': torch.tensor([0])}
given_vars = dict()
for var in ['a', 'c']:
    margs = TraceEnum_ELBO(max_plate_nesting=1).compute_marginals(
        model=model_2,
        guide=lambda *args, **kwargs: None,
        obs=given_vars)

    # No key error here, because c is now an enumerated sample site.
    p_joint *= margs[var].probs[joint_comb[var]]
    given_vars[var] = joint_comb[var]

# We see: the result is correct:
print(f'p(a=0,c=0): {p_joint}')
``````

Looking for a better solution:

However, this is somewhat hacky and introduces an additional enumeration dimension for c.
I would like to avoid this, since the number of tensor dimensions available for enumeration is limited by PyTorch and my models will be fairly large.

Final questions:

Main question:

1. How can I include deterministic sites without declaring additional enumeration sites?

Just hardcoding p(c) is not an option: in my real models c will deterministically depend on a combination of several earlier sample sites, so I would need to sum over all combinations of the sites that c depends on (cumbersome and slow).

2. Is there perhaps an easier way to compute the described joint probabilities, using Pyro-internal methods or something similar?

I can now answer question 2 on my own:

A better way to compute the joint probs would be:

``````
# p(a=0,b=0):
conditioned_model = pyro.condition(model_1, data={'a': torch.tensor([0]), 'b': torch.tensor([0])})
subfixed_model = pyro.poutine.enum(conditioned_model, first_available_dim=-1)
trace = pyro.poutine.trace(subfixed_model).get_trace(obs=dict())
print(trace.log_prob_sum().exp())  # joint probability
``````

However, I am still struggling to include deterministic sites.
Since the GitHub issue "Observing deterministically transformed output" (Issue #568, pyro-ppl/pyro) says that conditioning is only possible on sample sites, I replaced pyro.deterministic with a sample statement using the Delta distribution:

``````
samples['c'] = pyro.sample(name='c',
                           fn=dist.Delta(v=c_given_a[samples['a']]),
                           obs=obs['c'] if 'c' in obs else None)
``````

This leads to the expected result for impossible values of c:

``````
# p(a=0,c=1) = p(a=0) * p(c=1|a=0) = 0.5 * 0 = 0
conditioned_model = pyro.condition(model_1, data={'a': torch.tensor([0]), 'c': torch.tensor([1])})
subfixed_model = pyro.poutine.enum(conditioned_model, first_available_dim=-1)
trace = pyro.poutine.trace(subfixed_model).get_trace(obs=dict())
print(trace.log_prob_sum().exp())  # result is: 0.0
``````

but gives a wrong result for the possible value of c:

``````
# p(a=0,c=0) = p(a=0) * p(c=0|a=0) = 0.5 * 1 = 0.5
conditioned_model = pyro.condition(model_1, data={'a': torch.tensor([0]), 'c': torch.tensor([0])})
subfixed_model = pyro.poutine.enum(conditioned_model, first_available_dim=-1)
trace = pyro.poutine.trace(subfixed_model).get_trace(obs=dict())
print(trace.log_prob_sum().exp())  # result is: 0.0800
``````

Any ideas what I am missing?

I found the cause myself by looking into the code of log_prob_sum().

It simply adds up the log-probabilities of all elements of every site, including each enumerated value of b, rather than marginalizing over them. In the case above it therefore computes:

``````
(
    dist.Categorical(probs=torch.tensor([0.5, 0.5])).log_prob(torch.tensor([0]))
    + (torch.log(torch.tensor(0.2)) + torch.log(torch.tensor(0.8)))
    + dist.Delta(v=torch.tensor([0])).log_prob(torch.tensor([0]))
).exp()  # result is: 0.0800
``````
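The discrepancy can be reproduced without Pyro. The following sketch, in plain Python with the relevant entries of the original model_1 tables hardcoded, contrasts summing the log-probabilities of every enumerated value of b (which is what the expression above does) with marginalizing b by adding its probabilities. All variable names are illustrative only.

```python
import math

p_a0 = 0.5                 # p(a=0)
p_b_given_a0 = [0.2, 0.8]  # p(b | a=0), one entry per enumerated value of b
p_c0_given_a0 = 1.0        # Delta site: c=0 is certain when a=0

# Summing the log-probs of all enumerated b values multiplies them together.
log_sum = (math.log(p_a0)
           + sum(math.log(p) for p in p_b_given_a0)
           + math.log(p_c0_given_a0))
naive = math.exp(log_sum)

# Marginalizing b instead adds its probabilities (a log-sum-exp in log space).
marginal = p_a0 * sum(p_b_given_a0) * p_c0_given_a0

print(f"sum of log-probs: {naive:.4f}")
print(f"b marginalized:   {marginal:.4f}")
```

The first value matches the 0.0800 returned by log_prob_sum(), while the second is the intended p(a=0,c=0) = 0.5.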

Final solution of joint prob computation including deterministic sites:

The model:

``````
@config_enumerate
def model_1(obs: Dict):
    samples = dict()

    with pyro.plate('batch'):
        samples['a'] = pyro.sample(name='a',
                                   fn=dist.Categorical(probs=torch.tensor([0.5, 0.5])),
                                   obs=obs['a'] if 'a' in obs else None)

        b_given_a = torch.as_tensor([[0.2, 0.8], [1.0, 0.0]])
        samples['b'] = pyro.sample(name='b',
                                   fn=dist.Categorical(probs=b_given_a[samples['a']]),
                                   obs=obs['b'] if 'b' in obs else None)

        c_given_a = torch.tensor([0, 1])

        # We use the same distribution as pyro.deterministic, but as a pyro.sample site,
        # so that observations can be included.
        samples['c'] = pyro.sample(name='c',
                                   fn=dist.Delta(v=c_given_a[samples['a']]),
                                   obs=obs['c'] if 'c' in obs else None)

    return samples
``````

Example joint probability computation:

``````
# p(a=0,c=0) = p(a=0) * p(c=0|a=0) = 0.5 * 1 = 0.5
conditioned_model = pyro.condition(model_1, data={'a': torch.tensor([0]), 'c': torch.tensor([0])})
subfixed_model = pyro.poutine.enum(conditioned_model, first_available_dim=-1)
trace = pyro.poutine.trace(subfixed_model).get_trace(obs=dict())
evaluator = pyro.infer.mcmc.util.TraceEinsumEvaluator(trace, has_enumerable_sites=True, max_plate_nesting=1)
print(evaluator.log_prob(trace).exp())
``````
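To cross-check the value returned by the evaluator, a brute-force reference can be written in plain Python. This is only a sketch under the assumption that the tables of the final model are copied correctly; `joint_prob` is a hypothetical helper, not a Pyro function. It enumerates every (a, b) combination, derives c from a exactly as the Delta site does, and adds up the probability of all assignments that agree with the query.

```python
from itertools import product

# Tables copied from the final model_1.
p_a = [0.5, 0.5]                        # p(a)
p_b_given_a = [[0.2, 0.8], [1.0, 0.0]]  # p(b | a), rows indexed by a
c_given_a = [0, 1]                      # c is a deterministic function of a


def joint_prob(query: dict) -> float:
    """Sum p(a, b) over all assignments consistent with the queried values."""
    total = 0.0
    for a, b in product(range(2), range(2)):
        assignment = {"a": a, "b": b, "c": c_given_a[a]}
        if all(assignment[name] == value for name, value in query.items()):
            total += p_a[a] * p_b_given_a[a][b]
    return total


print(joint_prob({"a": 0, "c": 0}))  # the possible combination
print(joint_prob({"a": 0, "c": 1}))  # the impossible combination
```

This enumeration grows exponentially with the number of sites, so it is only useful for checking small models, not as a replacement for the enumeration machinery in Pyro.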

So I consider the issue solved.