Get pdfs for individual samples of a batch from a stochastic function

Hi! I have this stochastic function:

import torch
import pyro
from pyro import poutine
from pyro.distributions import Beta

def prior(batch_size, device):
    concentration = torch.tensor(3.4, device=device)
    translation_x = (pyro.sample('translation_x', Beta(concentration, concentration).expand(batch_size)) - 0.5) * 6.0
    translation_y = (pyro.sample('translation_y', Beta(concentration, concentration).expand(batch_size)) - 0.5) * 6.0
    return torch.stack((translation_x, translation_y))

With poutine.trace(prior).get_trace([10], 'cpu').log_prob_sum() I can calculate the log pdf of the whole batch. However, I want to calculate the log pdfs of the individual samples in the batch, not of all of them combined. What is the best way to do that?

you just need to inspect the trace in great detail. e.g.:

my_trace = poutine.trace(prior).get_trace(...)

for name, site in my_trace.nodes.items():
    if site["type"] == "sample":
        print(name, site["log_prob"].shape)

Thanks for the quick answer! Unfortunately, this doesn't quite work: your code gives a KeyError for "log_prob". Using the debugger, the site object looks like this (for a batch of 3):

{'type': 'sample', 'name': 'translation_x', 'fn': Beta(), 'is_observed': False, 'args': (), 'kwargs': {}, 'value': tensor([0.3447, 0.6366, 0.3450]), 'infer': {}, 'scale': 1.0, 'mask': None, 'cond_indep_stack': (), 'done': True, 'stop': False, 'continuation': None}

The documentation on Traces is unfortunately not that detailed, but it mentions that log_prob_sum lets you specify a site filter as a lambda function. Any idea what that would look like?
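From the source it looks like the filter receives the site name and the site dict, so I would guess something like this (untested sketch; as far as I can tell the result is still summed over the batch, so it wouldn't give per-sample log pdfs on its own):

# restrict log_prob_sum to a single site via site_filter;
# the result is still one summed scalar, not per-sample values
x_log_prob = my_trace.log_prob_sum(site_filter=lambda name, site: name == 'translation_x')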

Ah, I found the reason. After getting the trace, you first need to call my_trace.compute_log_prob(). Then it works!
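In short, for the prior above with a batch of 3:

my_trace = poutine.trace(prior).get_trace([3], 'cpu')
my_trace.compute_log_prob()  # populates site["log_prob"] at every sample site
print(my_trace.nodes['translation_x']['log_prob'].shape)  # torch.Size([3]), one log pdf per sample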

For anyone who faces the same problem, here is the full solution to get both the sample log pdfs and values from the trace:

n_samples = 3
my_trace = poutine.trace(prior).get_trace([n_samples], 'cpu')
my_trace.compute_log_prob()  # fills in site["log_prob"] for every sample site

# accumulate the per-sample log pdfs across all sample sites
# and collect the raw sampled values
probs = torch.zeros([n_samples])
value_list = []
for name, site in my_trace.nodes.items():
    if site["type"] == "sample":
        probs += site["log_prob"]
        value_list.append(site["value"])
values = torch.stack(value_list)
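As a sanity check (a sketch for the prior above), the accumulated probs should match the Beta log densities computed by hand from the raw site values, since the affine rescaling happens after the sample sites:

concentration = torch.tensor(3.4)
manual = (Beta(concentration, concentration).log_prob(my_trace.nodes['translation_x']['value'])
          + Beta(concentration, concentration).log_prob(my_trace.nodes['translation_y']['value']))
assert torch.allclose(probs, manual)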

However, I still don’t know how to access the actual return value of the stochastic function this way.

Okay, so it looks like you need to use my_trace.nodes['_RETURN']['value'] for that, and you can use poutine.broadcast for more elegant batching. So overall, this is the full solution (if you need the return value instead of the sample values):

from pyro.poutine.indep_messenger import IndepMessenger

@poutine.broadcast
def prior(batch_size, device):
    # declare an independent batch dimension so the sample sites
    # broadcast to shape [batch_size] automatically
    with IndepMessenger("batch", batch_size, dim=-1):
        concentration = torch.tensor(3.4, device=device)
        translation_x = (pyro.sample('translation_x', Beta(concentration, concentration)) - 0.5) * 6.0
        translation_y = (pyro.sample('translation_y', Beta(concentration, concentration)) - 0.5) * 6.0
    return torch.stack((translation_x, translation_y))

n_samples = 3
my_trace = poutine.trace(prior).get_trace(n_samples, 'cpu')
my_trace.compute_log_prob()
probs = torch.zeros([n_samples])
for name, site in my_trace.nodes.items():
    if site["type"] == "sample":
        probs += site["log_prob"]  # per-sample log pdf, accumulated over sites
prior_samples = my_trace.nodes['_RETURN']['value']  # the function's actual return value
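For clarity, these should be the resulting shapes with n_samples = 3:

print(probs.shape)          # torch.Size([3]): one log pdf per batch element
print(prior_samples.shape)  # torch.Size([2, 3]): stacked (x, y) translations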

you shouldn’t need the replay. there is a trace.nodes['_RETURN'] site

Thanks, I’ve updated my code accordingly.