Hi there, I am trying to modify the prodLDA tutorial abit to try the autoguidelist + custom_guide, with a prior on the decoder weights instead of just a MLE estimate
Here is my code
def model(self, x, edge_index):
pyro.module("decoder", self.decoder)
w = pyro.sample(
"w",
dist.Normal(
torch.zeros((15, self.n_genes), device=x.device),
torch.ones((15, self.n_genes), device=x.device),
).to_event(2),
)
with pyro.plate("data", x.shape[0]):
logtheta_loc = x.new_zeros((x.shape[0], self.n_topics))
logtheta_scale = x.new_ones((x.shape[0], self.n_topics))
logtheta = pyro.sample(
"logtheta", dist.Normal(logtheta_loc, logtheta_scale).to_event(1)
)
theta = F.softmax(logtheta, dim=1)
# mean, rate = self.decoder(theta)
mean = torch.matmul(theta, w)
mean = F.softmax(mean, dim=-1)
# rate = torch.exp(rate)
library_size = torch.sum(x, -1, keepdim=True)
# total_count = int(x.sum(-1).max())
samples = pyro.sample(
"obs", dist.Poisson(library_size * mean).to_event(1), obs=x
)
def custom_guide(self, x, edge_index):
pyro.module("encoder", self.encoder)
# w_loc = pyro.param("w_loc", torch.zeros((15, 6000), device=x.device))
# w_scale = pyro.param("w_scale", torch.zeros((15, 6000), device=x.device))
# w_scale = torch.sqrt(torch.exp(w_scale))
# w = pyro.sample("w", dist.Normal(w_loc, w_scale).to_event(2))
with pyro.plate("data", x.shape[0]):
# linkx encoder
logtheta_loc, logtheta_cov, logtheta_diag = self.encoder(x, edge_index)
logtheta = pyro.sample(
"logtheta",
dist.LowRankMultivariateNormal(
logtheta_loc, logtheta_cov, logtheta_diag
).to_event(0),
)
return F.softmax(logtheta, dim=1)
def guide(self):
columns_for_laplace_distribution = ["logtheta"]
my_guide = AutoGuideList(self.model)
my_guide.append(self.custom_guide)
my_guide.append(AutoNormal(pyro.poutine.block(self.model, expose=["w"])))
return my_guide
I am experiencing this error but Im not sure on how to proceed.
RuntimeError: Multiple sample sites named ‘data’