I have a model with an outer plate around an inner model that uses enumeration.
I was wondering if there is a clever way to break out the TraceEnum_ELBO.loss contribution of each outer plate? I was considering adding another model parameter which would only run a particular outer plate instead of all of them, then making a new AutoGuide, and using it by looping through the outer plates… But this smells a little clunky, and is probably slow.
Is there a more elegant approach? I was hoping for something where I override part of the ELBO mechanism (maybe rewrite .loss()?), but I don’t have a good handle on how that works. Any guidance on a general approach would be appreciated!
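Well… I found that using a poutine mask and looping was exceptionally slow, but a very small diff gives me what I’m looking for. It adds a breakout flag: when it is set, Dice.compute_expectation keeps the elementwise product of prob and cost instead of contracting it to a scalar, and accumulates it as a detached NumPy array, so the result is for inspection only and is not differentiable. It is probably not general enough for broad use, but I’m posting it here in case it helps anyone. The diff is against the dev branch.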
diff --git a/pyro/infer/traceenum_elbo.py b/pyro/infer/traceenum_elbo.py
index c3b094f7..f38e4ed4 100644
--- a/pyro/infer/traceenum_elbo.py
+++ b/pyro/infer/traceenum_elbo.py
@@ -120,7 +120,7 @@ def _compute_model_factors(model_trace, guide_trace):
     return marginal_costs, log_factors, ordering, enum_dims, scale
-def _compute_dice_elbo(model_trace, guide_trace):
+def _compute_dice_elbo(model_trace, guide_trace, breakout=False):
     # Accumulate marginal model costs.
     marginal_costs, log_factors, ordering, sum_dims, scale = _compute_model_factors(
         model_trace, guide_trace)
@@ -155,7 +155,7 @@ def _compute_dice_elbo(model_trace, guide_trace):
             cost = packed.neg(site["packed"]["log_prob"])
             costs.setdefault(ordering[name], []).append(cost)
-    return Dice(guide_trace, ordering).compute_expectation(costs)
+    return Dice(guide_trace, ordering).compute_expectation(costs, breakout)
 def _make_dist(dist_, logits):
diff --git a/pyro/infer/util.py b/pyro/infer/util.py
index f20a7ac8..74f2ae31 100644
--- a/pyro/infer/util.py
+++ b/pyro/infer/util.py
@@ -209,7 +209,7 @@ class Dice(object):
         return log_factors
-    def compute_expectation(self, costs):
+    def compute_expectation(self, costs, breakout=False):
         """
         Returns a differentiable expected cost, summing over costs at given ordinals.
@@ -263,7 +263,10 @@ class Dice(object):
                     cost = cost[mask]
                 else:
                     cost, prob = packed.broadcast_all(cost, prob)
-                expected_cost = expected_cost + scale * torch.tensordot(prob, cost, prob.dim())
+                if not breakout:
+                    expected_cost = expected_cost + scale * torch.tensordot(prob, cost, prob.dim())
+                else:
+                    expected_cost = expected_cost + scale * torch.mul(prob, cost).detach().numpy()
         LAST_CACHE_SIZE[0] = count_cached_ops(cache)
         return expected_cost
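To make the last hunk concrete, here is a minimal sketch of the arithmetic it changes, using plain PyTorch tensors. The tensor values and the layout (rows as outer-plate slices, columns as enumerated values) are assumptions made up for illustration; in Pyro the packed tensors' dimension order depends on the model, and the masking branch can flatten them, so which axis corresponds to the outer plate has to be checked case by case.

```python
import torch

# Hypothetical stand-ins for the packed `prob` and `cost` tensors:
# 3 outer-plate slices (rows) by 2 enumerated values (columns).
prob = torch.tensor([[0.5, 0.5], [0.9, 0.1], [0.2, 0.8]])
cost = torch.tensor([[1.0, 3.0], [2.0, 4.0], [5.0, 0.0]])

# Unpatched path: tensordot over all prob.dim() dimensions
# contracts everything into a single 0-d tensor.
total = torch.tensordot(prob, cost, prob.dim())

# breakout=True path: the elementwise product keeps every entry,
# so the per-slice terms are still visible.
elementwise = torch.mul(prob, cost)

# Summing only over the (assumed) enumeration axis leaves one
# contribution per outer-plate slice.
per_plate = elementwise.sum(dim=-1)

print(total)       # scalar contribution of this cost term
print(per_plate)   # one value per outer-plate slice
```

Summing `per_plate` recovers `total`, which is why the patched branch is a breakdown of the same quantity rather than a different objective. Note that the patch also calls `.detach().numpy()`, which assumes CPU tensors, and that the terms accumulated into `expected_cost` must have broadcast-compatible shapes for the running sum to work.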