Subsampling in a plate makes a program grind to a halt

Hi All,

I started experimenting with Pyro a couple of days ago and have been enjoying it.

I wrote a simple LDA model and was able to fit it to a small simulated dataset with NUTS and infer={"enumerate": "parallel"} on CUDA.

I then tried subsampling the words within documents, and unfortunately the algorithm stopped working. By that I mean it appears to be doing something, but not a single iteration completes even after a long wait. Here is the code:

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import MCMC, NUTS

def model(M, T, D, W, Z):
    # Per-topic distributions over the W vocabulary words
    with pyro.plate("topics", T):
        topic_words = pyro.sample("beta", dist.Dirichlet(0.1 * torch.ones(W)))
    with pyro.plate("documents", D) as ind:
        data = M[:, ind]
        # Per-document distributions over the T topics
        doc_topics = pyro.sample("doc_topics", dist.Dirichlet(50 / T * torch.ones(T)))
        # Subsample 10 of the Z word positions on every model call
        with pyro.plate("words", Z, subsample_size=10) as ind_w:
            data = data[ind_w, :]
            word_topics = pyro.sample("word_topics", dist.Categorical(doc_topics),
                                      infer={"enumerate": "parallel"})
            pyro.sample("doc_words", dist.Categorical(topic_words[word_topics]), obs=data)

nuts_kernel = NUTS(model, jit_compile=True, ignore_jit_warnings=True)
mcmc = MCMC(nuts_kernel, num_samples=1, warmup_steps=20)
mcmc.run(M, T, D, W, Z)

Why does the above not work, when everything works as expected if I drop "subsample_size=10"?

As a bonus, could you point me to some resources on managing CUDA memory? With a larger dataset I quickly run out of memory on Google Colab.

Thanks in advance!

MCMC does not support subsampling. I think the correct behavior in this case would have been to raise an error.
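For intuition on why subsampling and MCMC don't mix, here is a rough sketch in plain Python (no Pyro involved). It assumes a toy Gaussian likelihood, and the helper names `full_log_lik` and `subsampled_log_lik` are made up for illustration; they are not Pyro functions. The idea it mimics is that a plate with subsample_size draws a fresh random subset on each model call and rescales the log-likelihood by size / subsample_size, so the density becomes a noisy estimate rather than a fixed function of the parameters, whereas HMC/NUTS rely on evaluating the same target density at every step.

```python
import math
import random

def log_lik(x, mu):
    # Log-density of a unit-variance Gaussian at x (toy likelihood, an assumption)
    return -0.5 * (x - mu) ** 2 - 0.5 * math.log(2 * math.pi)

def full_log_lik(data, mu):
    # Exact log-likelihood over all data points: deterministic in mu
    return sum(log_lik(x, mu) for x in data)

def subsampled_log_lik(data, mu, subsample_size, rng):
    # Draw a fresh random subsample and rescale by size / subsample_size,
    # which is how a subsampled plate weights the log-likelihood
    batch = rng.sample(data, subsample_size)
    scale = len(data) / subsample_size
    return scale * sum(log_lik(x, mu) for x in batch)

rng = random.Random(0)
data = [rng.gauss(1.0, 1.0) for _ in range(1000)]
mu = 0.5  # the same parameter value is used for every evaluation

exact = [full_log_lik(data, mu) for _ in range(3)]
noisy = [subsampled_log_lik(data, mu, 10, rng) for _ in range(3)]

print("full data: ", exact)   # identical on every call
print("subsampled:", noisy)   # changes from call to call
```

With the full data, repeated evaluations at the same mu agree exactly; with a subsample of 10, each call returns a different value. That kind of noise is fine for stochastic-gradient methods, but it means the sampler has no single well-defined density to explore.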

Oh I see. Since it didn’t throw an error I was confused.