Multiple sites sample error for autoguide + custom_guide

Hi there, I am trying to modify the prodLDA tutorial abit to try the autoguidelist + custom_guide, with a prior on the decoder weights instead of just a MLE estimate

Here is my code

def model(self, x, edge_index):
        pyro.module("decoder", self.decoder)

        w = pyro.sample(
            "w",
            dist.Normal(
                torch.zeros((15, self.n_genes), device=x.device),
                torch.ones((15, self.n_genes), device=x.device),
            ).to_event(2),
        )

        with pyro.plate("data", x.shape[0]):
            logtheta_loc = x.new_zeros((x.shape[0], self.n_topics))
            logtheta_scale = x.new_ones((x.shape[0], self.n_topics))
            logtheta = pyro.sample(
                "logtheta", dist.Normal(logtheta_loc, logtheta_scale).to_event(1)
            )
            theta = F.softmax(logtheta, dim=1)
            # mean, rate = self.decoder(theta)
            mean = torch.matmul(theta, w)
            mean = F.softmax(mean, dim=-1)
            # rate = torch.exp(rate)
            library_size = torch.sum(x, -1, keepdim=True)
            # total_count = int(x.sum(-1).max())
            samples = pyro.sample(
                "obs", dist.Poisson(library_size * mean).to_event(1), obs=x
            )

    def custom_guide(self, x, edge_index):
        pyro.module("encoder", self.encoder)
        # w_loc = pyro.param("w_loc", torch.zeros((15, 6000), device=x.device))
        # w_scale = pyro.param("w_scale", torch.zeros((15, 6000), device=x.device))
        # w_scale = torch.sqrt(torch.exp(w_scale))
        # w = pyro.sample("w", dist.Normal(w_loc, w_scale).to_event(2))
        with pyro.plate("data", x.shape[0]):
            # linkx encoder
            logtheta_loc, logtheta_cov, logtheta_diag = self.encoder(x, edge_index)
            logtheta = pyro.sample(
                "logtheta",
                dist.LowRankMultivariateNormal(
                    logtheta_loc, logtheta_cov, logtheta_diag
                ).to_event(0),
            )

            return F.softmax(logtheta, dim=1)
    
    def guide(self):
        columns_for_laplace_distribution = ["logtheta"]
        my_guide = AutoGuideList(self.model)
        my_guide.append(self.custom_guide)
        my_guide.append(AutoNormal(pyro.poutine.block(self.model, expose=["w"])))
        return my_guide

I am experiencing this error but Im not sure on how to proceed.
RuntimeError: Multiple sample sites named ‘data’

when you call pyro.plate or pyro.sample the string containing the name, e.g. “data” in those two code chunks have to be unique:
with pyro.plate(“data”, x.shape[0]):
logtheta_loc = x.new_zeros((x.shape[0], self.n_topics))
and
def custom_guide(self, x, edge_index):
pyro.module(“encoder”, self.encoder)
# w_loc = pyro.param(“w_loc”, torch.zeros((15, 6000), device=x.device))
# w_scale = pyro.param(“w_scale”, torch.zeros((15, 6000), device=x.device))
# w_scale = torch.sqrt(torch.exp(w_scale))
# w = pyro.sample(“w”, dist.Normal(w_loc, w_scale).to_event(2))
with pyro.plate(“data”, x.shape[0]):

you can define your pyro.plate before that like so:
self.sample_plate = pyro.plate(“data”, x.shape[0])

and then call “with self.sample_plate:” in your functions