ELBO contribution of plates

I have a model with an outer plate wrapped around an inner model that uses enumeration.

I was wondering if there is a clever way to break out the TraceEnum_ELBO.loss contribution of each element of the outer plate. I considered adding another model argument that would run only one particular outer plate element instead of all of them, then making a new AutoGuide and looping over the outer plate elements with it… But this feels a little clunky, and is probably slow.

Is there a more elegant approach? I was hoping to override part of the ELBO machinery (maybe rewrite .loss()?), but I don't have a good handle on how that works. Any guidance on a general approach would be appreciated!
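To make the target concrete, here is a toy, pure-Python analog (not Pyro code): a two-component Gaussian mixture whose assignment z_n is local to the outer plate and is summed out exactly, as parallel enumeration does. Assuming the mixture weights and locations are fixed parameters (global latent variables in the guide would add terms that belong to no single plate element), the loss is just the sum of per-element terms. The sketch also mimics the masking loop: masking out all but one element recovers that element's term, but costs a full pass over the data per element. All function names and data here are hypothetical.

```python
import math

def normal_logpdf(x, loc, scale=1.0):
    # log N(x | loc, scale^2)
    return -0.5 * ((x - loc) / scale) ** 2 - math.log(scale) - 0.5 * math.log(2 * math.pi)

def logsumexp(values):
    # Numerically stable log(sum(exp(v)))
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))

def element_loss(x, weights, locs):
    # -log sum_k p(z = k) p(x | z = k): the enumerated assignment is summed out exactly
    return -logsumexp([math.log(w) + normal_logpdf(x, m) for w, m in zip(weights, locs)])

def masked_loss(data, weights, locs, mask):
    # Analog of running the whole model under a 0/1 mask: every element's
    # term is still computed, but masked-out terms are multiplied by zero.
    total, evaluations = 0.0, 0
    for x, keep in zip(data, mask):
        term = element_loss(x, weights, locs)
        evaluations += 1
        total += (1.0 if keep else 0.0) * term
    return total, evaluations

data = [-2.1, 0.3, 1.9, 2.2]
weights, locs = [0.4, 0.6], [-2.0, 2.0]
n = len(data)

# Single pass: per-element breakout, whose sum is the total loss.
breakout = [element_loss(x, weights, locs) for x in data]
total = sum(breakout)

# Masking loop: one full pass per element to isolate that element's term.
looped, work = [], 0
for i in range(n):
    loss_i, evals = masked_loss(data, weights, locs, [j == i for j in range(n)])
    looped.append(loss_i)
    work += evals

print("per-element:", [round(v, 3) for v in breakout])
print("total:", round(total, 3), "| element terms computed by the mask loop:", work)
```

The loop computes n² element terms to get what the single pass yields with n, which is consistent with the masking approach being slow for large plates.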

Well… I found that using poutine.mask and looping was exceptionally slow, but a very small diff gives me what I'm looking for. It's probably not general enough for broad use, but I'm posting it here in case it helps anyone. The diff is against the dev branch.

diff --git a/pyro/infer/traceenum_elbo.py b/pyro/infer/traceenum_elbo.py
index c3b094f7..f38e4ed4 100644
--- a/pyro/infer/traceenum_elbo.py
+++ b/pyro/infer/traceenum_elbo.py
@@ -120,7 +120,7 @@ def _compute_model_factors(model_trace, guide_trace):
     return marginal_costs, log_factors, ordering, enum_dims, scale

-def _compute_dice_elbo(model_trace, guide_trace):
+def _compute_dice_elbo(model_trace, guide_trace, breakout=False):
     # Accumulate marginal model costs.
     marginal_costs, log_factors, ordering, sum_dims, scale = _compute_model_factors(
             model_trace, guide_trace)
@@ -155,7 +155,7 @@ def _compute_dice_elbo(model_trace, guide_trace):
             cost = packed.neg(site["packed"]["log_prob"])
             costs.setdefault(ordering[name], []).append(cost)

-    return Dice(guide_trace, ordering).compute_expectation(costs)
+    return Dice(guide_trace, ordering).compute_expectation(costs, breakout)

 def _make_dist(dist_, logits):
diff --git a/pyro/infer/util.py b/pyro/infer/util.py
index f20a7ac8..74f2ae31 100644
--- a/pyro/infer/util.py
+++ b/pyro/infer/util.py
@@ -209,7 +209,7 @@ class Dice(object):

         return log_factors

-    def compute_expectation(self, costs):
+    def compute_expectation(self, costs, breakout=False):
         """
         Returns a differentiable expected cost, summing over costs at given ordinals.

@@ -263,7 +263,10 @@ class Dice(object):
                         cost = cost[mask]
                         cost, prob = packed.broadcast_all(cost, prob)
-                    expected_cost = expected_cost + scale * torch.tensordot(prob, cost, prob.dim())
+                    if not breakout:
+                        expected_cost = expected_cost + scale * torch.tensordot(prob, cost, prob.dim())
+                    else:
+                        expected_cost = expected_cost + scale * torch.mul(prob, cost).detach().numpy()

         LAST_CACHE_SIZE[0] = count_cached_ops(cache)
         return expected_cost
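The change in compute_expectation amounts to swapping a full contraction for an element-wise product. The NumPy sketch below (NumPy standing in for torch; the shapes are hypothetical, with an enumerated dim of size 3 on the left and an outer plate of size 5 on the right) shows that summing the element-wise products recovers the scalar tensordot result, and that summing out only the enumerated dim leaves one value per plate element. Note that in the patch `.detach().numpy()` drops the autograd graph, so the breakout is for inspection rather than for gradient steps; `.numpy()` also requires a CPU tensor, and arrays accumulated into expected_cost from different ordinals must broadcast against each other.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical shapes: enumerated dim (size 3) on the left, outer plate (size 5) on the right.
prob = rng.random((3, 5))
prob /= prob.sum(axis=0, keepdims=True)  # normalize over the enumerated dim
cost = rng.normal(size=(3, 5))

# Original behaviour: contract all of prob's dims against cost -> 0-d result.
expected_scalar = np.tensordot(prob, cost, prob.ndim)

# Patched behaviour: keep the element-wise products instead.
elementwise = np.multiply(prob, cost)

# Summing out only the enumerated dim gives one expected cost per plate element.
per_plate = elementwise.sum(axis=0)

print(expected_scalar.shape, per_plate.shape)
```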