How to compute joint probability distribution including deterministic sites?

Short Overview

I would appreciate some help on how to compute the joint probability distribution over several enumerated sample sites when some of the sites are pyro.deterministic sites.

Below, I describe a concrete example that I would like to adapt.


Basic Situation

Suppose we have a model with categorical sample sites and a deterministic sample site:

from typing import Dict
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import config_enumerate, TraceEnum_ELBO


@config_enumerate
def model_1(obs:Dict) :
    samples = dict()
    samples['a'] = pyro.sample(name='a',
                               fn=dist.Categorical(probs=torch.tensor([0.5,0.5])),
                               obs=obs['a'] if 'a' in obs else None)

    b_given_a = torch.as_tensor([[0.2,0.8],[0.4,0.6]])
    samples['b'] = pyro.sample(name='b',
                               fn=dist.Categorical(probs= b_given_a[samples['a']]),
                               obs = obs['b'] if 'b' in obs else None)



    c_given_a = torch.tensor([0,1])
    samples['c'] = pyro.deterministic(name='c', value=c_given_a[samples['a']])

    return samples

Joint distribution over pyro.sample sites:

We can compute the joint distribution
p(a=0,b=0) = p(a=0) * p(b=0|a=0) = 0.5 * 0.2 = 0.1
by the chain rule using the marginal distributions as follows:

p_joint=1
joint_comb={'a':torch.tensor([0]),'b':torch.tensor([0])}
given_vars=dict()
for var in ['a','b']:
    margs=TraceEnum_ELBO(max_plate_nesting=1).compute_marginals(model=model_1,
                                                          guide=lambda *args, **kwargs: None,
                                                          obs=given_vars)

    p_joint*=margs[var].probs[joint_comb[var]]
    given_vars[var]=joint_comb[var]

#We see: the result is correct:
print(f'p(a=0,b=0): {p_joint}')
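
As a sanity check that does not depend on Pyro, the same chain-rule product can be reproduced in plain Python from the probability tables in model_1. This is only a sketch; the names p_a, p_b_given_a and joint_ab are hypothetical helpers for illustration, not part of any Pyro API.

```python
# Probability tables copied from model_1.
p_a = [0.5, 0.5]
p_b_given_a = [[0.2, 0.8], [0.4, 0.6]]


def joint_ab(a: int, b: int) -> float:
    """Chain rule: p(a, b) = p(a) * p(b | a)."""
    return p_a[a] * p_b_given_a[a][b]


# p(a=0, b=0), which should agree with the Pyro result above.
print(joint_ab(0, 0))
```

Summing joint_ab over all four (a, b) combinations should give 1, which is a quick way to confirm the tables were copied correctly.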

Joint distribution over pyro.deterministic sites:

However, if we try to include pyro.deterministic sites as in
p(a=0, c=0) = p(a=0) * p(c=0|a=0) = 0.5 * 1 = 0.5

we get a key error:

p_joint=1
joint_comb={'a':torch.tensor([0]),'c':torch.tensor([0])}
given_vars=dict()
for var in ['a','c']:
    margs=TraceEnum_ELBO(max_plate_nesting=1).compute_marginals(model=model_1,
                                                          guide=lambda *args, **kwargs: None,
                                                          obs=given_vars)

    #This line leads to key error for deterministic sites.
    p_joint*=margs[var].probs[joint_comb[var]]

    given_vars[var]=joint_comb[var]

print(f'p(a=0,c=0): {p_joint}')

Quick workaround to include pyro.deterministic sites:

We can circumvent the problem by defining c as a 'dummy categorical', as in the following model. Only the site c was changed compared to the former model:

@config_enumerate
def model_2(obs:Dict) :
    samples = dict()
    samples['a'] = pyro.sample(name='a',
                               fn=dist.Categorical(probs=torch.tensor([0.5,0.5])),
                               obs=obs['a'] if 'a' in obs else None)

    b_given_a = torch.as_tensor([[0.2,0.8],[0.4,0.6]])
    samples['b'] = pyro.sample(name='b',
                               fn=dist.Categorical(probs= b_given_a[samples['a']])
                               )



    #we make c a sample site instead of deterministic, with 'deterministic probs'
    c_given_a = torch.as_tensor([[1.0, 0.0], [0.0, 1.0]])
    samples['c'] = pyro.sample(name='c',
                               fn=dist.Categorical(probs=c_given_a[samples['a']])
                               )

    return samples

With this change, we can compute the joint probability:
p(a=0, c=0) = p(a=0) * p(c=0|a=0) = 0.5 * 1 = 0.5

in the same manner as before, without any key errors:

p_joint=1
joint_comb={'a':torch.tensor([0]),'c':torch.tensor([0])}
given_vars=dict()
for var in ['a','c']:
    margs=TraceEnum_ELBO(max_plate_nesting=1).compute_marginals(model=model_2,
                                                          guide=lambda *args, **kwargs: None,
                                                          obs=given_vars)

    #No key error here, since c is now a regular sample site.
    p_joint*=margs[var].probs[joint_comb[var]]
    given_vars[var]=joint_comb[var]

#We see: the result is correct:
print(f'p(a=0,c=0): {p_joint}')

Looking for a better solution:

However, this is somewhat hacky and introduces an additional enumeration dimension for c.
This is something I would like to avoid, since the number of enumeration dimensions is limited by PyTorch and my models will be pretty big.


Final questions:

Main question:

  1. How can I include deterministic sites without declaring additional enumeration sites?

    Just hardcoding p(c) is not an option: in my real models c will depend deterministically on a combination of several earlier sample sites, so I would need to sum over all combinations of the sites c depends on (cumbersome and slow).
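
To make the cost of hardcoding concrete, here is a minimal sketch of the brute-force sum in plain Python. The tables are taken from model_1, but the function f (XOR of a and b) is purely an assumption for illustration; it stands in for whatever deterministic dependency the real model has, and marginal_c is a hypothetical helper name.

```python
from itertools import product

# Tables copied from model_1.
p_a = [0.5, 0.5]
p_b_given_a = [[0.2, 0.8], [0.4, 0.6]]


def f(a: int, b: int) -> int:
    """HYPOTHETICAL deterministic function c = f(a, b), for illustration only."""
    return a ^ b


def marginal_c(value: int) -> float:
    """p(c=value) = sum of p(a, b) over all (a, b) with f(a, b) == value."""
    total = 0.0
    for a, b in product(range(2), repeat=2):
        if f(a, b) == value:
            total += p_a[a] * p_b_given_a[a][b]
    return total


print(marginal_c(0), marginal_c(1))
```

The loop visits every combination of parent values, so its cost grows exponentially with the number of sites that c depends on.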

Additional:

  2. Is there an easier way to compute the joint probabilities described above, for example using Pyro's internal methods?

I would be grateful for any help. Thanks in advance.

I can now answer question 2 myself:

A better way to compute the joint probs would be:

#p(a=0,b=0):
subfixed_model=pyro.poutine.enum(pyro.condition(model_1, data={'a':torch.tensor([0]),'b':torch.tensor([0])}), first_available_dim=-1)
trace = pyro.poutine.trace(subfixed_model).get_trace(obs=dict())
print(trace.log_prob_sum().exp()) #joint probability

However, I am still struggling to include deterministic sites.
Since the GitHub issue "Observing deterministically transformed output" (pyro-ppl/pyro, Issue #568) says that conditioning is only possible on sample sites, I replaced pyro.deterministic with a sample statement using the Delta distribution:

samples['c'] = pyro.sample(name='c',
                          fn=dist.Delta(v=c_given_a[samples['a']]),
                          obs = obs['c'] if 'c' in obs else None)

This leads to the expected result for impossible values of c:

#p(a=0,c=1)= p(a=0) * p(c=1|a=0) = 0.5 * 0 = 0
subfixed_model=pyro.poutine.enum(pyro.condition(model_1, data={'a':torch.tensor([0]),'c':torch.tensor([1])}), first_available_dim=-1)
trace = pyro.poutine.trace(subfixed_model).get_trace(obs=dict())
print(trace.log_prob_sum().exp()) #result is: 0.0

but gives the wrong result for the possible value of c:

#p(a=0,c=0)= p(a=0) * p(c=0|a=0) = 0.5 * 1 = 0.5
subfixed_model=pyro.poutine.enum(pyro.condition(model_1, data={'a':torch.tensor([0]),'c':torch.tensor([0])}), first_available_dim=-1)
trace = pyro.poutine.trace(subfixed_model).get_trace(obs=dict())
print(trace.log_prob_sum().exp()) #result is: 0.0800

Any ideas what I am missing?

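I found the problem here as well, by looking into the code of log_prob_sum().

It simply sums the log probabilities of all sites. Since b is unobserved and enumerated, its log_prob has one entry per enumerated value, and both entries are added instead of being marginalized out. In the case above it therefore computes: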

(
    dist.Categorical(probs=torch.tensor([0.5,0.5])).log_prob(torch.tensor([0]))  # site a
    + (torch.log(torch.tensor(0.2)) + torch.log(torch.tensor(0.8)))  # site b: both enumerated values are added
    + dist.Delta(v=torch.tensor([0])).log_prob(torch.tensor([0]))  # site c
).exp() #result is: 0.0800
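
The difference between adding the log probabilities of all enumerated values and marginalizing over them can be sketched in plain Python. The variable names below are hypothetical, and the numbers are the ones from the model with a=0 and c=0 fixed; this only illustrates the arithmetic and is not how Pyro implements it internally.

```python
import math

p_a0 = 0.5                 # p(a=0)
p_b_given_a0 = [0.2, 0.8]  # p(b | a=0); b is unobserved, so both values are enumerated
p_c0_given_a0 = 1.0        # Delta site: c=0 is certain when a=0

log_b = [math.log(p) for p in p_b_given_a0]

# Adding every entry of every site's log_prob, including both values of b.
summed = math.log(p_a0) + sum(log_b) + math.log(p_c0_given_a0)

# Marginalizing b with log-sum-exp, which is what a joint probability needs.
log_marginal_b = math.log(sum(math.exp(lp) for lp in log_b))
marginalized = math.log(p_a0) + log_marginal_b + math.log(p_c0_given_a0)

print(math.exp(summed), math.exp(marginalized))
```

The first value corresponds to 0.5 * 0.2 * 0.8, the second to 0.5 * (0.2 + 0.8), so only the marginalized version matches the expected p(a=0, c=0).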

Final solution of joint prob computation including deterministic sites:

The model:

@config_enumerate
def model_1(obs:Dict):
    samples = dict()

    with pyro.plate('batch'):
        samples['a'] = pyro.sample(name='a',
                                   fn=dist.Categorical(probs=torch.tensor([0.5,0.5])),
                                   obs=obs['a'] if 'a' in obs else None)

        b_given_a = torch.as_tensor([[0.2, 0.8], [1.0, 0.0]])
        samples['b'] = pyro.sample(name='b',
                                   fn=dist.Categorical(probs= b_given_a[samples['a']]),
                                   obs = obs['b'] if 'b' in obs else None)



        c_given_a = torch.tensor([0,1])


        #we use a Delta distribution, which mirrors pyro.deterministic, but as a pyro.sample site
        #so that observations can be included
        samples['c'] = pyro.sample(name='c',
                                  fn=dist.Delta(v=c_given_a[samples['a']]),
                                  obs = obs['c'] if 'c' in obs else None)

    return samples

Example joint probability computation:

#p(a=0,c=0)= p(a=0) * p(c=0|a=0) = 0.5 * 1 = 0.5
subfixed_model=pyro.poutine.enum(pyro.condition(model_1, data={'a':torch.tensor([0]),'c':torch.tensor([0])}), first_available_dim=-1)
trace = pyro.poutine.trace(subfixed_model).get_trace(obs=dict())
print(pyro.infer.mcmc.util.TraceEinsumEvaluator(trace, has_enumerable_sites=True, max_plate_nesting=1).log_prob(trace).exp())

So I consider the issue solved.