 # Marginal probability of single assignment to multiple variables

Let’s say I have a simple model.

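    from pyro import sample
    from pyro.distributions import Normal

    def model():
        # Shared latent z; a and b are conditionally independent given z
        z = sample("z", Normal(0, 1))
        return sample("a", Normal(z, 1)), sample("b", Normal(z, 1))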


My question is: what’s the simplest way to evaluate the marginal joint probability of a single assignment to a and b? e.g.

$$P(a = 1, b = 1) = \int P(a = 1, b = 1, z)\,dz$$
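For this particular model the integral can be done by hand, which gives a reference value to check any estimator against. Since a = z + noise and b = z + noise with independent unit-variance noise, the pair (a, b) is bivariate Gaussian with zero mean and covariance [[2, 1], [1, 2]]. The sketch below evaluates that log density in plain Python; `exact_log_marginal` is a hypothetical helper name, not part of Pyro, and it only applies to this Gaussian toy model.

```python
import math

def exact_log_marginal(a, b):
    """Log density of (a, b) with z integrated out analytically.

    Assumes z ~ N(0, 1), a | z ~ N(z, 1), b | z ~ N(z, 1), so that
    (a, b) ~ N(0, [[2, 1], [1, 2]]).
    """
    det = 2.0 * 2.0 - 1.0 * 1.0  # determinant of the covariance matrix
    # Quadratic form x^T Sigma^{-1} x, with Sigma^{-1} = [[2, -1], [-1, 2]] / det
    quad = (2.0 * a * a - 2.0 * a * b + 2.0 * b * b) / det
    # Bivariate Gaussian log density: -quad/2 - log(2*pi) - log(det)/2
    return -0.5 * quad - math.log(2.0 * math.pi) - 0.5 * math.log(det)

print(exact_log_marginal(1.0, 1.0))
```

Note that this is a log density rather than a probability mass, since a and b are continuous.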

For example, I can use trace to find the log joint density of a full assignment, where z takes whatever value happens to be sampled:

    trace(condition(model, {"a": torch.tensor(1.0), "b": torch.tensor(1.0)})).get_trace().log_prob_sum()


But that doesn’t marginalize over z. I can use EmpiricalMarginal to compute a marginal distribution:

    EmpiricalMarginal(Importance(model).run(), sites=["a", "b"])


But the empirical distribution only puts mass on values that were actually drawn, so it can't score an arbitrary assignment that I supply: a specific assignment such as a = 1, b = 1 will almost surely never be sampled.
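One way to see what is being asked for: the integral is an expectation over the prior on z, so it can be estimated by drawing z from the prior and averaging the likelihood of the fixed values of a and b. The sketch below does this in plain Python for the toy model, without any Pyro calls; `normal_logpdf` and `mc_log_marginal` are hypothetical helper names, and the sample count and seed are arbitrary choices.

```python
import math
import random

def normal_logpdf(x, mean, std):
    """Log density of a univariate Normal(mean, std) at x."""
    var = std * std
    return -0.5 * math.log(2.0 * math.pi * var) - (x - mean) ** 2 / (2.0 * var)

def mc_log_marginal(a, b, num_samples=20000, seed=0):
    """Estimate log p(a, b) = log E_{z ~ N(0, 1)}[p(a | z) p(b | z)]."""
    rng = random.Random(seed)
    log_weights = []
    for _ in range(num_samples):
        z = rng.gauss(0.0, 1.0)  # draw the latent from its prior
        # Log-likelihood of the fixed observations given this z
        log_weights.append(normal_logpdf(a, z, 1.0) + normal_logpdf(b, z, 1.0))
    # log-mean-exp: subtract the max before exponentiating for stability
    m = max(log_weights)
    total = sum(math.exp(w - m) for w in log_weights)
    return m + math.log(total) - math.log(num_samples)

print(mc_log_marginal(1.0, 1.0))
```

Sampling from the prior is the simplest proposal, but it can be very inefficient when the observations are unlikely under most prior draws; a proposal closer to the posterior over z would reduce the variance.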

Also, as an aside: is replay, rather than condition, the more appropriate way to evaluate a joint probability using trace? If so, are there any facilities for manually constructing Trace objects?

I think I have a simple Monte Carlo solution using importance sampling, in case anyone else would find this useful.

    import torch
    from pyro import condition, do, infer
    from pyro.poutine import trace

    # Compute probability of observed variables marginalizing out the latents
    # obs is a dict of variable name: value
    # latents is a list of variable names
    def marginal_sample_prob(model, obs, latents, num_samples=100):
        # Draw samples from posterior over latent variables conditioning on observations
        cond_model = condition(model, obs)
        marginal_approx_dist = infer.Importance(cond_model, num_samples=num_samples).run()
        empirical_marginal = infer.EmpiricalMarginal(marginal_approx_dist, sites=latents)

        # Compute marginal probability as an expectation using Monte Carlo estimate on posterior samples
        # (obs_probs holds log-probabilities, one per sample)
        obs_probs = torch.zeros(num_samples)
        for i in range(num_samples):
            latent_vals_list = empirical_marginal()
            latent_vals_dict = {latent: val for latent, val in zip(latents, latent_vals_list)}

            # Using "do" ensures that the values of latent variables will not affect the probability
            # of the trace
            obs_prob = trace(do(cond_model, latent_vals_dict)).get_trace().log_prob_sum()
            obs_prob += empirical_marginal.log_prob(latent_vals_list)
            obs_probs[i] = obs_prob

        # Sum probabilities in a numerically stable way