Different GMM Observation Shapes with SVI vs. Custom Loss

I used the GMM tutorial to put together a working program, and the mixture model fit my data well. Now I want to add a custom loss term, so I separated out the ELBO loss function and am using a manual training loop, as described in the Custom SVI tutorial.

When I originally used the ELBO object with the built-in SVI training loop, my observations had shape (num_clusters, num_points, num_features). With the added custom loss, I'm calling trace on the guide, and the observation has shape (num_points, num_features). My intent was to get the points assigned to each cluster and compute some loss on them. Why am I getting different observation shapes from the two approaches, and how do I get back the original shape (num_clusters, num_points, num_features)?

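```python
from pyro import poutine
from pyro.infer import infer_discrete

def mutual_info_loss(self):
    # One cluster label per data point (MAP labels, since temperature defaults to 0)
    cluster_assignments = self.get_classes()
    guide_trace = poutine.trace(self.guide).get_trace(self.data)
    # k_means = guide_trace.nodes['cluster_means']['value']
    k_covars = guide_trace.nodes['cluster_covars']['value']
    # Integer indexing: row i is self.obs[cluster_assignments[i]]
    points = self.obs[cluster_assignments]
    print(points.shape)

def get_classes(self, temperature=0.0):
    # Run the guide, then replay the model using the guide's sampled values
    guide_trace = poutine.trace(self.guide).get_trace(self.data)
    current_model = poutine.replay(self.model, trace=guide_trace)

    # temperature=0 gives MAP assignments; enumeration starts at dim -2
    inferred_model = infer_discrete(
        current_model, temperature=temperature, first_available_dim=-2
    )
    trace = poutine.trace(inferred_model).get_trace(self.data)
    return trace.nodes['cluster_assignments']['value']
```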
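On the goal of getting the points assigned to each cluster: if the assignments are one label per point (shape (num_points,)), boolean masks give either a list of per-cluster tensors or a zero-padded (num_clusters, num_points, num_features) tensor. This is a minimal sketch in plain PyTorch; `points_by_cluster` and `masked_by_cluster` are hypothetical helper names, and the data are toy values.

```python
import torch

def points_by_cluster(obs, assignments, num_clusters):
    """Hypothetical helper: list of (n_k, num_features) tensors, one per cluster."""
    return [obs[assignments == k] for k in range(num_clusters)]

def masked_by_cluster(obs, assignments, num_clusters):
    """Hypothetical helper: (num_clusters, num_points, num_features) tensor,
    with rows zeroed for points not assigned to that cluster."""
    # membership[k, i] is True when point i is assigned to cluster k
    membership = assignments == torch.arange(num_clusters).unsqueeze(-1)
    return obs.unsqueeze(0) * membership.unsqueeze(-1).to(obs.dtype)

obs = torch.tensor([[0.0, 0.5], [10.0, 10.0], [0.5, 0.0], [9.0, 11.0], [1.0, 1.0]])
assignments = torch.tensor([0, 1, 0, 1, 0])

groups = points_by_cluster(obs, assignments, 2)
print([g.shape for g in groups])  # [torch.Size([3, 2]), torch.Size([2, 2])]
masked = masked_by_cluster(obs, assignments, 2)
print(masked.shape)  # torch.Size([2, 5, 2])
```

The list form is usually simpler for per-cluster statistics; the padded form matches the (num_clusters, num_points, num_features) layout, but the zeroed rows have to be excluded from any statistic computed on it.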

Thanks in advance!