With poutine.trace(prior).get_trace([10], 'cpu').log_prob_sum() I can calculate the log pdf of the whole batch. However, I want the log pdf of each individual sample in the batch rather than the sum over all of them. What is the best way to do that?
You just need to inspect the trace in more detail, e.g.:
```python
my_trace = poutine.trace(prior).get_trace(...)
for name, site in my_trace.nodes.items():
    if site["type"] == "sample":
        print(name, site["log_prob"].shape)
```
Thanks for the quick answer! Unfortunately, this doesn't quite work: your code raises a KeyError for "log_prob". In the debugger, the site object looks like this (for a batch of 3):
The documentation on Traces is unfortunately not that detailed, but it mentions that log_prob_sum accepts a site filter as a lambda function. Any idea what that would look like?
Okay, so it looks like you need to use my_trace.nodes['_RETURN']['value'] for that, and you can use poutine.broadcast for more elegant batching. So overall, this is the full solution (if you need the return value instead of the sample values):