The line loss = svi.step(batch_docs, batch_doc_sum)
is triggering the error:
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
I printed all tensors, they are on cuda device… Can someone help? Thanks!
Here’s the code:
# Loading the data, everything on cuda
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
with open('../input/prodlda/docs.pkl', "rb") as fp:
docs = pickle.load(fp)
doc_sum = torch.from_numpy(
pd.read_csv('../input/prodlda/doc_sum.csv', index_col=0).values
).float().to(device)
docs = [torch.from_numpy(doc).to(device) for doc in docs]
# When I create the model, I put it into cuda...
prodLDA = ProdLDA(
vocab_size=doc_sum.shape[1],
num_topics=20,
hidden=100,
dropout=0.2,
device=device
)
prodLDA.to(device)
# I'm putting every tensor on cuda... don't know what I'm missing...
class ProdLDA(nn.Module):
def __init__(self, vocab_size, num_topics, hidden, dropout, device):
super().__init__()
self.vocab_size = vocab_size
self.num_topics = num_topics
self.inference_net = Encoder(vocab_size, num_topics, hidden, dropout)
self.device = device
def model(self, docs=None, doc_sum=None):
# Globals.
with pyro.plate("topics", self.num_topics):
topic_weights = pyro.sample(
"topic_weights", dist.Gamma(1. / self.num_topics, 1.)
).to(self.device)
topic_words = pyro.sample(
"topic_words", dist.Dirichlet(torch.ones(self.vocab_size) / self.vocab_size)
).to(self.device)
# Locals.
for doc_idx in pyro.plate("documents", len(docs)):
doc_topics = pyro.sample("doc_topics_%d" % doc_idx,
dist.Dirichlet(topic_weights))
for word_idx in pyro.plate("words_%d" % doc_idx, len(docs[doc_idx])):
word_topics = pyro.sample(
"word_topics_%d_%d" % (doc_idx, word_idx),
dist.Categorical(doc_topics),
infer={"enumerate": "parallel"}
)
pyro.sample("doc_word_%d_%d" % (doc_idx, word_idx),
dist.Categorical(topic_words[word_topics]),
obs=docs[doc_idx][word_idx])
def guide(self, docs=None, doc_sum=None):
# Use a conjugate guide for global variables.
topic_weights_posterior = pyro.param(
"topic_weights_posterior",
lambda: torch.ones(self.num_topics).to(self.device),
constraint=constraints.positive)
topic_words_posterior = pyro.param(
"topic_words_posterior",
lambda: torch.ones(self.num_topics, self.vocab_size).to(self.device),
constraint=constraints.greater_than(0.5))
with pyro.plate("topics", self.num_topics):
pyro.sample("topic_weights", dist.Gamma(topic_weights_posterior, 1.))
pyro.sample("topic_words", dist.Dirichlet(topic_words_posterior))
# Use an amortized guide for local variables.
pyro.module("inference_net", self.inference_net)
with pyro.plate("documents", doc_sum.shape[0]):
doc_topics = self.inference_net(doc_sum)
pyro.sample("doc_topics", dist.Delta(doc_topics, event_dim=1))
Thanks!