Pyro complains I'm using 2 devices (cuda & cpu), while all tensors are on cuda

The line loss = svi.step(batch_docs, batch_doc_sum) is triggering the error:

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

I printed all tensors, they are on cuda device… Can someone help? Thanks!
Here’s the code:

# Loading the data, everything on cuda
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

with open('../input/prodlda/docs.pkl', "rb") as fp:
    docs = pickle.load(fp)

doc_sum = torch.from_numpy(
    pd.read_csv('../input/prodlda/doc_sum.csv', index_col=0).values
).float().to(device)
docs = [torch.from_numpy(doc).to(device) for doc in docs]
# When I create the model, I put it into cuda...
prodLDA = ProdLDA(
    vocab_size=doc_sum.shape[1],
    num_topics=20,
    hidden=100,
    dropout=0.2,
    device=device
)
prodLDA.to(device)
# I'm putting every tensor on cuda... don't know what I'm missing...
class ProdLDA(nn.Module):
    def __init__(self, vocab_size, num_topics, hidden, dropout, device):
        super().__init__()
        self.vocab_size = vocab_size
        self.num_topics = num_topics
        self.inference_net = Encoder(vocab_size, num_topics, hidden, dropout)
        self.device = device

    def model(self, docs=None, doc_sum=None):
        # Globals.
        with pyro.plate("topics", self.num_topics):
            topic_weights = pyro.sample(
                "topic_weights", dist.Gamma(1. / self.num_topics, 1.)
            ).to(self.device)
            topic_words = pyro.sample(
                "topic_words", dist.Dirichlet(torch.ones(self.vocab_size) / self.vocab_size)
            ).to(self.device)

        # Locals.
        for doc_idx in pyro.plate("documents", len(docs)):
            doc_topics = pyro.sample("doc_topics_%d" % doc_idx,
                                     dist.Dirichlet(topic_weights))
            for word_idx in pyro.plate("words_%d" % doc_idx, len(docs[doc_idx])):
                word_topics = pyro.sample(
                    "word_topics_%d_%d" % (doc_idx, word_idx),
                    dist.Categorical(doc_topics),
                    infer={"enumerate": "parallel"}
                )
                pyro.sample("doc_word_%d_%d" % (doc_idx, word_idx),
                            dist.Categorical(topic_words[word_topics]),
                            obs=docs[doc_idx][word_idx])

    def guide(self, docs=None, doc_sum=None):
        # Use a conjugate guide for global variables.
        topic_weights_posterior = pyro.param(
            "topic_weights_posterior",
            lambda: torch.ones(self.num_topics).to(self.device),
            constraint=constraints.positive)
        
        topic_words_posterior = pyro.param(
            "topic_words_posterior",
            lambda: torch.ones(self.num_topics, self.vocab_size).to(self.device),
            constraint=constraints.greater_than(0.5))
        
        with pyro.plate("topics", self.num_topics):
            pyro.sample("topic_weights", dist.Gamma(topic_weights_posterior, 1.))
            pyro.sample("topic_words", dist.Dirichlet(topic_words_posterior))

        # Use an amortized guide for local variables.
        pyro.module("inference_net", self.inference_net)
        with pyro.plate("documents", doc_sum.shape[0]):
            doc_topics = self.inference_net(doc_sum)
            pyro.sample("doc_topics", dist.Delta(doc_topics, event_dim=1))

Thanks!

i think you need to make sure all the params of distributions are cuda. there’s no need to move the samples. in particular replace torch.ones() with something like prototype_tensor.new_ones()

1 Like

Yes, it worked! Thanks! Curiously moving samples to cuda doesn’t work… I had to do:

a = torch.tensor(1. / self.num_topics, device=self.device)
b = torch.tensor(1., device=self.device)
topic_weights = pyro.sample("topic_weights", dist.Gamma(a, b))
            
alpha = torch.ones(self.vocab_size, device=self.device) / self.vocab_size
topic_words = pyro.sample("topic_words", dist.Dirichlet(alpha))

That fixed the issue… thanks!!