Marginal probability of single assignment to multiple variables

Let’s say I have a simple model.

def model():
    z = sample("z", Normal(0, 1))
    return sample("a", Normal(z, 1)), sample("b", Normal(z, 1))

My question is: what’s the simplest way to evaluate the marginal joint probability of a single assignment to a and b? e.g.

P(a = 1, b = 1) = \int_z P(a = 1, b = 1, z = z)dz
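For this particular model the integral happens to have a closed form, which gives a reference value to check any estimator against. Since a = z + noise and b = z + noise with independent unit-variance noise, (a, b) is jointly Gaussian with zero mean and covariance [[2, 1], [1, 2]]. A minimal sketch in plain Python follows; the function name is hypothetical and nothing here uses Pyro:

```python
import math

def gaussian_marginal_density(a, b):
    """Density of (a, b) after integrating out z ~ Normal(0, 1),
    where a | z ~ Normal(z, 1) and b | z ~ Normal(z, 1).

    Marginally, (a, b) is bivariate normal with mean 0 and
    covariance [[2, 1], [1, 2]].
    """
    det = 2.0 * 2.0 - 1.0 * 1.0  # determinant of the covariance matrix
    # Quadratic form x^T Sigma^{-1} x, with Sigma^{-1} = [[2, -1], [-1, 2]] / det
    quad = (2.0 * a * a - 2.0 * a * b + 2.0 * b * b) / det
    return math.exp(-0.5 * quad) / (2.0 * math.pi * math.sqrt(det))

print(gaussian_marginal_density(1.0, 1.0))
```

This only works because everything in the model is Gaussian; for a general model the integral has to be approximated.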

For example, I can use trace to find the log-probability of a joint assignment:

trace(condition(model, {"a": 1, "b": 1})).get_trace().log_prob_sum()

But that doesn’t marginalize over z. I can use EmpiricalMarginal to compute a marginal distribution:

EmpiricalMarginal(Importance(model).run(), sites=["a", "b"])

But this won’t assign probability mass to an arbitrary assignment that I supply, since a specific assignment such as a = 1, b = 1 is very unlikely to ever be sampled.

Also, as an aside: is replay a more appropriate way than condition to evaluate a joint probability using trace? If so, are there any facilities for manually constructing Trace objects?

I think I have a simple Monte Carlo solution using importance sampling, in case anyone else would find this useful.

import torch
import pyro.distributions as dist
from pyro import infer, sample
from pyro.poutine import condition, do, trace

# Compute the log marginal probability of the observed variables, marginalizing out the latents
# obs is a dict of variable name: value
# latents is a list of variable names
def marginal_sample_prob(model, obs, latents, num_samples=100):
    # Draw samples from posterior over latent variables conditioning on observations
    cond_model = condition(model, obs)
    marginal_approx_dist = infer.Importance(cond_model, num_samples=num_samples).run()
    empirical_marginal = infer.EmpiricalMarginal(marginal_approx_dist, sites=latents)
    
    # Compute marginal probability as an expectation using Monte Carlo estimate on posterior samples
    obs_probs = torch.zeros(num_samples)
    for i in range(num_samples):
        latent_vals_list = empirical_marginal()        
        latent_vals_dict = {latent: val for latent, val in zip(latents, latent_vals_list)}
        
        # Using "do" ensures that the values of latent variables will not affect the probability
        # of the trace
        obs_prob = trace(do(cond_model, latent_vals_dict)).get_trace().log_prob_sum()
        obs_prob += empirical_marginal.log_prob(latent_vals_list)
        obs_probs[i] = obs_prob

    # Sum probabilities in a numerically stable way
    return torch.logsumexp(obs_probs, dim=0)

def model():   
    z = sample("z", dist.Normal(0, 1))
    return sample("a", dist.Normal(z, 1)), sample("b", dist.Normal(z, 1))

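# Observed values are passed as tensors so that log_prob accepts them
print(marginal_sample_prob(model, {"a": torch.tensor(1.0), "b": torch.tensor(1.0)}, ["z"]).exp())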
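As a sanity check that does not depend on Pyro, the same marginal can be estimated by drawing z from the prior and averaging the likelihood of the observations. This is a swapped-in, simpler estimator (likelihood weighting with the prior as proposal), not the posterior-sampling scheme above, and the function names are hypothetical:

```python
import math
import random

def normal_logpdf(x, mean, std):
    """Log density of Normal(mean, std) evaluated at x."""
    return -0.5 * ((x - mean) / std) ** 2 - math.log(std) - 0.5 * math.log(2.0 * math.pi)

def log_marginal_estimate(a, b, num_samples=20000, seed=0):
    """Estimate log P(a, b) = log E_{z ~ Normal(0, 1)}[P(a | z) P(b | z)]."""
    rng = random.Random(seed)
    log_liks = []
    for _ in range(num_samples):
        z = rng.gauss(0.0, 1.0)  # sample the latent from its prior
        log_liks.append(normal_logpdf(a, z, 1.0) + normal_logpdf(b, z, 1.0))
    # log-mean-exp: logsumexp of the log-likelihoods minus log(num_samples)
    m = max(log_liks)
    return m + math.log(sum(math.exp(l - m) for l in log_liks)) - math.log(num_samples)

print(math.exp(log_marginal_estimate(1.0, 1.0)))
```

Note the subtraction of log(num_samples): a plain logsumexp gives the log of a sum rather than the log of an average. With the default settings the estimate for a = 1, b = 1 should come out close to 0.066.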