Hi there,
While building SVI models, I keep running into problems where enumeration expands tensor dimensions and runs out of memory in parallel mode.
The basic problem arises in the following simple scenario:
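```python
import torch
import pyro
import pyro.distributions as dist

batch_size = 32
latent_dim = 1000
probs1 = torch.rand(batch_size, 100)
probs2 = torch.rand(batch_size, 100)
probs3 = torch.rand(batch_size, 100)
# one latent_dim-sized location for every (i, j, k) combination of categories
locs = torch.ones(probs1.shape[1], probs2.shape[1], probs3.shape[1], latent_dim)

def model():
    # a plate marks the batch dimension; .to_event(1) would turn each site into an
    # Independent distribution, which does not support enumeration
    with pyro.plate("batch", batch_size):
        discrete_out1 = pyro.sample("cat1", dist.Categorical(probs1), infer={"enumerate": "parallel"})
        discrete_out2 = pyro.sample("cat2", dist.Categorical(probs2), infer={"enumerate": "parallel"})
        discrete_out3 = pyro.sample("cat3", dist.Categorical(probs3), infer={"enumerate": "parallel"})
        # under parallel enumeration this index broadcasts over all 100 x 100 x 100 combinations
        return locs[discrete_out1, discrete_out2, discrete_out3]
```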
During enumeration, the result of indexing locs is expanded to shape [i, j, k, batch_size, latent_dim], which makes CUDA run out of memory.
My understanding is that parallel enumeration gives each discrete site its own tensor dimension, so all combinations of values are evaluated at once in a single vectorized pass, which is what makes it fast. However, in a hierarchical model like this one the expanded tensor has i × j × k × latent_dim elements, so even on a large GPU I quickly run out of memory.
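To put numbers on this, here is a quick back-of-the-envelope calculation for the sizes in the example above, assuming float32 storage (4 bytes per element) and ignoring any extra temporaries PyTorch allocates:

```python
import math

def dense_bytes(shape, bytes_per_elem=4):
    """Bytes needed by a dense tensor of the given shape (float32 by default)."""
    return math.prod(shape) * bytes_per_elem

i = j = k = 100
batch_size, latent_dim = 32, 1000

# enumerated locs before it meets per-example terms: i x j x k x latent_dim
per_combo = dense_bytes([i, j, k, latent_dim])
# after broadcasting against the batch: i x j x k x batch_size x latent_dim
with_batch = dense_bytes([i, j, k, batch_size, latent_dim])

print(per_combo / 1e9, "GB")
print(with_batch / 1e9, "GB")
```

That works out to 4 GB for the enumerated locs alone and 128 GB once the batch dimension is broadcast in, which is why a single GPU runs out of memory.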
The best workaround I've found is to use parallel enumeration for as many sites as memory allows and sequential enumeration for the rest, but this is extremely slow: sequential enumeration seems to be more than 100x slower than parallel for each discrete site. It is usually faster to train the model on a CPU than to use sequential enumeration.
Is there any way to get some of the speed of parallel enumeration without the exploding memory usage?
Thanks!!!
Matthew
Side note:
As a compromise between speed and memory, I've been working around the problem by approximating the nested indexing with an einsum over the probabilities instead of indexing with the discrete samples, but this is a rather sloppy solution.
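For reference, a minimal sketch of what such an einsum approximation might look like, assuming the three categoricals are independent within each batch row and normalizing the probabilities first (rows from torch.rand do not sum to 1). The helper name expected_locs is hypothetical. It computes the expected location E[locs[c1, c2, c3]] in a single contraction rather than indexing with enumerated or sampled values:

```python
import torch

def expected_locs(probs1, probs2, probs3, locs):
    """Expected value of locs[c1, c2, c3] with c1, c2, c3 independent per batch row.

    probs1: (batch, n_i), probs2: (batch, n_j), probs3: (batch, n_k)
    locs:   (n_i, n_j, n_k, latent_dim)
    returns (batch, latent_dim)
    """
    # normalize so each row is a proper distribution
    p1 = probs1 / probs1.sum(-1, keepdim=True)
    p2 = probs2 / probs2.sum(-1, keepdim=True)
    p3 = probs3 / probs3.sum(-1, keepdim=True)
    # weight every (i, j, k) location by p1 * p2 * p3 and sum over i, j, k
    return torch.einsum("bi,bj,bk,ijkd->bd", p1, p2, p3, locs)

# small illustrative sizes
torch.manual_seed(0)
batch_size, n, latent_dim = 2, 3, 4
probs1, probs2, probs3 = (torch.rand(batch_size, n) for _ in range(3))
locs = torch.randn(n, n, n, latent_dim)
approx = expected_locs(probs1, probs2, probs3, locs)
print(approx.shape)
```

The result is a deterministic mixture of locations, so it replaces the discrete choice with its mean; that is what makes it an approximation of the enumerated model rather than an equivalent of it.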