Using Nested Enumerate without Exploding Memory

Hi there,
While building SVI models, I keep running into problems where enumeration expands dimensions and runs out of memory in parallel mode.

The basic problem arises in the following simple scenario:

import torch
import pyro
import pyro.distributions as dist

batch_size = 32
latent_dim = 1000
probs1 = torch.rand(batch_size, 100)
probs2 = torch.rand(batch_size, 100)
probs3 = torch.rand(batch_size, 100)
locs = torch.ones(probs1.shape[1], probs2.shape[1], probs3.shape[1], latent_dim)

cat1 = dist.Categorical(probs1).to_event(1)
discrete_out1 = pyro.sample('cat1', cat1, infer={'enumerate': 'parallel'})
cat2 = dist.Categorical(probs2).to_event(1)
discrete_out2 = pyro.sample('cat2', cat2, infer={'enumerate': 'parallel'})
cat3 = dist.Categorical(probs3).to_event(1)
discrete_out3 = pyro.sample('cat3', cat3, infer={'enumerate': 'parallel'})
locs[discrete_out1, discrete_out2, discrete_out3]

Here locs is expanded to [i, j, k, batch_size, latent_dim] during enumeration, which triggers a CUDA out-of-memory error.

My understanding is that when I sample discrete sites and enumerate them in parallel, Pyro expands the enumerated dimensions so everything can be processed in one vectorized pass, which is very fast. However, for a hierarchical model this expansion is i × j × k × batch_size × latent_dim in memory, so even on a large GPU I quickly run out of RAM.
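A quick way to see where this expansion happens is to print the enumerated trace shapes. This is only a sketch: it assumes the sampling code is wrapped in a model() function, and first_available_dim would need to sit to the left of any plate dims in the real model.

    from pyro import poutine

    # Sketch: enumerate the discrete sites and print every site's shape,
    # which is where the [i, j, k, batch_size, latent_dim] expansion shows up.
    trace = poutine.trace(
        poutine.enum(model, first_available_dim=-2)
    ).get_trace()
    trace.compute_log_prob()
    print(trace.format_shapes())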

The best solution I’ve come up with is to use parallel enumeration for as many sites as I can and sequential enumeration for the rest, but this is extremely slow: sequential seems to be >100x slower than parallel for each discrete site. It’s usually faster to train the model on a CPU than to use the ‘sequential’ mode of enumeration.
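For example, this is what I mean by mixing modes, reusing the toy example’s sites; only the infer flag differs per site:

    # Sketch of mixing enumeration modes per site (toy example's names).
    # 'parallel' sites get a broadcast enumeration dim; the 'sequential' site
    # is enumerated one value at a time, which avoids part of the big tensor
    # but is much slower.
    discrete_out1 = pyro.sample('cat1', cat1, infer={'enumerate': 'parallel'})
    discrete_out2 = pyro.sample('cat2', cat2, infer={'enumerate': 'parallel'})
    discrete_out3 = pyro.sample('cat3', cat3, infer={'enumerate': 'sequential'})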

I’m wondering: is there any way to get some of the speed of parallel enumeration without the exploding-memory problem?

Thanks!!!
Matthew

Side note:
As a compromise between speed and memory, I’ve been getting around the problem by approximating the nested indexing with an einsum over the probabilities instead of the discrete samples (roughly the sketch below), but this is a rather sloppy solution.
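Concretely, the approximation looks something like this (a sketch using the toy example’s shapes, not the actual model code):

    # Sketch of the einsum approximation: probs1/2/3 are [batch_size, 100],
    # locs is [100, 100, 100, latent_dim]. Instead of indexing locs with
    # sampled indices, average locs over the categorical probabilities; the
    # result is [batch_size, latent_dim], so nothing ever expands to
    # [100, 100, 100, batch_size, latent_dim].
    expected_locs = torch.einsum('bi,bj,bk,ijkl->bl', probs1, probs2, probs3, locs)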

At first glance I can say that you don’t need .to_event(1) for the Categorical distributions in your case, since they seem to be conditionally independent along the batch_size dimension. You can declare this using the pyro.plate primitive (see the Pyro docs on pyro.plate for details):

cat1 = dist.Categorical(probs1)
...
with pyro.plate("batch_dim", size=batch_size):
    discrete_out1 = pyro.sample("cat1", cat1, infer={'enumerate':'parallel'})
    ...
    locs[discrete_out1, discrete_out2, discrete_out3] # be careful here

I would also be careful with the last line and make sure that locs is indexed properly using Pyro’s Vindex helper tool.
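To make that concrete (your toy example’s names; Vindex lives in pyro.ops.indexing):

    from pyro.ops.indexing import Vindex

    # Inside the plate above, index locs through Vindex so the enumeration
    # dims of the three sample sites broadcast correctly against locs.
    selected = Vindex(locs)[discrete_out1, discrete_out2, discrete_out3]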

Hope this helps.

Sorry about those bugs; I just wrote up a quick toy example instead of the actual model.

I guess my question remains, though, since both:

    locs[discrete_out1, discrete_out2, discrete_out3] # be careful here

and

    Vindex(locs)[discrete_out1, discrete_out2, discrete_out3] # be careful here

suffer from the memory-explosion problem: in both cases the result is expanded from [32, 1000] to [100, 100, 100, 32, 1000], which inevitably triggers an out-of-memory error.

Do you mean sequential enumeration of the discrete sites here? Have you tried using parallel enumeration for all discrete sites together with the sequential version of pyro.plate (for i in pyro.plate(...):)? In that case, after indexing, locs will expand to only [100, 100, 100, 1000], and the for loop will have just 32 iterations, which might be reasonably fast.
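Roughly like this (a sketch using your toy example’s names; note that each site needs a unique name per iteration of a sequential plate):

    from pyro.ops.indexing import Vindex

    # Sketch: sequential plate over the batch, parallel enumeration inside.
    # Each iteration handles one datapoint, so after indexing the enumerated
    # tensor is only [100, 100, 100, latent_dim] rather than
    # [100, 100, 100, batch_size, latent_dim].
    for i in pyro.plate("batch_dim", batch_size):
        d1 = pyro.sample(f"cat1_{i}", dist.Categorical(probs1[i]),
                         infer={'enumerate': 'parallel'})
        d2 = pyro.sample(f"cat2_{i}", dist.Categorical(probs2[i]),
                         infer={'enumerate': 'parallel'})
        d3 = pyro.sample(f"cat3_{i}", dist.Categorical(probs3[i]),
                         infer={'enumerate': 'parallel'})
        selected_i = Vindex(locs)[d1, d2, d3]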

Maybe other people on the forum have suggestions too.

Yeah, I mean that I use parallel enumeration for some discrete sites and sequential for others. Looping through the batch is a reasonable idea and would help the memory problem somewhat, but it would hurt performance for other parts of the model that rely on the vectorized batch plate.

I really just want the performance of parallel enumeration without the transient expansion of the enumerated variables growing too large, but I’m not sure if that’s possible 🙂