Subsampling not scaling to larger datasets, and a potential solution

I’ve been trying to identify why a model with a constant subsample size is much slower on more data, even though the number of parameters in the model stays constant. I have a solution that I want to put out there in case it helps anyone, and I want to make sure I’m not stepping on any foot-guns.

import numpy as np  # used later for the host-side permutation
import jax
import jax.numpy as jnp
from jax import random
import numpyro
from numpyro import distributions as dist
from numpyro.infer import SVI, Trace_ELBO, autoguide, TraceMeanField_ELBO

n_cats = 2

start_key = random.PRNGKey(100)
mu_key, std_key, prop_key, svi_key, post_pred_key, samp_key = random.split(start_key, 6)
mus = random.normal(mu_key, shape=(n_cats,))
stds = random.lognormal(mu_key, shape=(n_cats,))
proportion = random.uniform(prop_key)
probs = jnp.hstack([proportion, 1 - proportion])
# Ground-truth mixture we'll try to recover
mix_dist = dist.MixtureSameFamily(
    mixing_distribution=dist.CategoricalProbs(probs=probs),
    component_distribution=dist.Normal(mus, stds),
)

targ_params = {'mu': mus, 'stds': stds, 'probs': probs}
display(targ_params)

#  {'mu': Array([0.87981087, 2.25013   ], dtype=float32),
#  'stds': Array([2.4104438, 9.488969 ], dtype=float32),
#  'probs': Array([0.45550573, 0.5444943 ], dtype=float32)}

def model(data):
    with numpyro.plate('cat_plate', n_cats):
        mus = numpyro.sample('mu', dist.Normal())
        stds = numpyro.sample('stds', dist.LogNormal())
    logit = numpyro.sample('logit', dist.Normal())
    probs = numpyro.deterministic('probs', jnp.hstack((jax.nn.sigmoid(logit),
                                                       jax.nn.sigmoid(-logit))))
    # Only 500 observations are touched per step, regardless of data.shape[0]
    with numpyro.plate('obs_plate', data.shape[0], subsample_size=500):
        numpyro.sample('obs',
                       dist.MixtureSameFamily(
                           mixing_distribution=dist.CategoricalProbs(probs=probs),
                           component_distribution=dist.Normal(mus, stds)),
                       obs=numpyro.subsample(data, event_dim=0))

This is the model, with a subsample_size of 500 that we’re going to keep constant. Dataset 1 has 50,000 samples and dataset 2 has 5,000,000 samples, but each step only evaluates the likelihood of a 500-point subsample, so the optimization process is dealing with the same number of parameters (and, in principle, the same amount of per-step work) regardless of dataset size.
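
As a quick sanity check (a rough sketch; I’m assuming the traced 'obs' site exposes its value and scale this way), we can trace one evaluation of the model and confirm that the observed site only carries a 500-element subsample, rescaled by N / 500:

from numpyro import handlers

# Hypothetical check: trace one model call and inspect the subsampled 'obs' site
dummy_data = jnp.zeros(50_000)
tr = handlers.trace(handlers.seed(model, random.PRNGKey(0))).get_trace(dummy_data)
print(tr['obs']['value'].shape)   # expected: (500,) -- only the subsample is evaluated
print(tr['obs']['scale'])         # expected: 50_000 / 500 = 100 (log-likelihood rescaling)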

Here is dataset 1:

%%time

samps, samp_cat = mix_dist.sample_with_intermediates(random.PRNGKey(10), (50_000,))
samp_cat = samp_cat[0]

guide = autoguide.AutoNormal(model)
svi = SVI(guide.model, guide, numpyro.optim.Adam(5e-4), TraceMeanField_ELBO())

rkey = random.PRNGKey(0)
svi_state = svi.init(rng_key=rkey, data=samps)

@jax.jit
def body_fn(svi_state, _):
    svi_state, loss = svi.update(svi_state, data=samps)
    return svi_state, loss

# 10,000 SVI steps, each subsampling 500 points from the full dataset
svi_state, losses = jax.lax.scan(body_fn, svi_state, None, length=10_000)
display(guide.median(svi.get_params(svi_state)))

# {'mu': Array([1.8945895 , 0.89211637], dtype=float32),
#  'stds': Array([9.617037 , 2.4565895], dtype=float32),
#  'logit': Array(0.14584522, dtype=float32),
#  'probs': Array([0.5363968 , 0.46360317], dtype=float32)}

# CPU times: user 3.05 s, sys: 32.9 ms, total: 3.09 s
# Wall time: 2.99 s

And here is dataset 2:

%%time

samps, samp_cat = mix_dist.sample_with_intermediates(random.PRNGKey(10), (5_000_000,))
samp_cat = samp_cat[0]

guide = autoguide.AutoNormal(model)
svi = SVI(guide.model, guide, numpyro.optim.Adam(5e-4), TraceMeanField_ELBO())

rkey = random.PRNGKey(0)
svi_state = svi.init(rng_key=rkey, data=samps)

@jax.jit
def body_fn(svi_state, _):
    svi_state, loss = svi.update(svi_state, data=samps)
    return svi_state, loss

svi_state, losses = jax.lax.scan(body_fn, svi_state, None, length=10_000)
display(guide.median(svi.get_params(svi_state)))

# {'mu': Array([1.8770313, 0.9183824], dtype=float32),
#  'stds': Array([9.509483 , 2.4253767], dtype=float32),
#  'logit': Array(0.18706162, dtype=float32),
#  'probs': Array([0.5466295 , 0.45337048], dtype=float32)}

# CPU times: user 31.3 s, sys: 728 ms, total: 32 s
# Wall time: 11.5 s

So 100x the data takes about 3.8x as long (11.5 s vs. 3.0 s wall time), even though each step still only subsamples 500 points.
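
My guess was that the per-step cost is dominated by gathering the 500-element subsample out of the full data array, which scales with the array size rather than the subsample size. A rough, hypothetical way to check that in isolation (time_gather is just an illustrative helper):

import time

def time_gather(n, n_trials=100):
    # Time only the gather data[idx] for a fixed 500-element random index
    data = random.normal(random.PRNGKey(0), (n,))
    idx = random.randint(random.PRNGKey(1), (500,), 0, n)
    data[idx].block_until_ready()                  # warm-up
    t0 = time.perf_counter()
    for _ in range(n_trials):
        data[idx].block_until_ready()
    return (time.perf_counter() - t0) / n_trials

# compare e.g. time_gather(50_000) vs time_gather(5_000_000)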

However, if we split Dataset 2 into 100 chunks the same size as Dataset 1, scan over the chunks, and run 100 steps per chunk (so the total is still 10,000 steps), we can train on Dataset 2 in roughly the same time as Dataset 1.

%%time
samps, samp_cat = mix_dist.sample_with_intermediates(random.PRNGKey(10), (5_000_000,))
samp_cat = samp_cat[0]

%%time
guide = autoguide.AutoNormal(model)
svi = SVI(guide.model, guide, numpyro.optim.Adam(5e-4), TraceMeanField_ELBO())

# 100 chunks of 50,000 indices each
# perm_idx = random.permutation(random.PRNGKey(1000), len(samps)).reshape(100, -1)
perm_idx = jnp.asarray(np.random.RandomState(0).permutation(len(samps))).reshape(100, -1)

rkey = random.PRNGKey(0)
svi_state = svi.init(rng_key=rkey, data=samps[perm_idx[0]])

def large_scan(svi_state, idx):
    # 100 SVI steps on one 50,000-element chunk; the model still subsamples 500 per step
    @jax.jit
    def body_fn(svi_state, _):
        svi_state, loss = svi.update(svi_state, data=samps[idx])
        return svi_state, loss

    svi_state, losses = jax.lax.scan(body_fn, svi_state, None, length=100)
    return svi_state, losses


# Outer scan over the 100 chunks: 100 x 100 = 10,000 total steps
svi_state, lossL = jax.lax.scan(large_scan, svi_state, perm_idx)
losses = jnp.hstack(lossL)
display(guide.median(svi.get_params(svi_state)))

# {'mu': Array([1.8874398 , 0.90744805], dtype=float32),
#  'stds': Array([9.485589 , 2.4041011], dtype=float32),
#  'logit': Array(0.17867826, dtype=float32),
#  'probs': Array([0.5445511, 0.4554489], dtype=float32)}

# CPU times: user 3.21 s, sys: 59.8 ms, total: 3.27 s
# Wall time: 3.22 s

I used NumPy’s permutation because it’s much faster than JAX’s (an issue I know has come up a few times), but you get the idea.

So am I doing anything wrong by using this approach? The large_scan approach should be functionally equivalent, right? I hope so, because while this isn’t the exact problem I’m trying to solve, I am working with some enormous datasets and I’d love to be able to capitalize on a ~4x speed-up. Thanks in advance!

Your observation is reasonable. The slice x[subsample_idx] is expensive for large x. In practice, we can use a dataloader to generate a batch of data (in NumPy) per step, as in your code.
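
For concreteness, here is a rough sketch of that pattern, reusing the names from the code above (the chunk size and step counts are just placeholders): slice each chunk with NumPy on the host and hand only that chunk to the jitted update loop, so the device never has to index into the full 5,000,000-element array.

samps_np = np.asarray(samps)                     # keep the full dataset on the host
perm_idx = np.random.RandomState(0).permutation(len(samps_np)).reshape(100, -1)

@jax.jit
def chunk_steps(svi_state, chunk):
    # 100 SVI steps on one host-sliced chunk; the model still subsamples 500 per step
    def body_fn(state, _):
        state, loss = svi.update(state, data=chunk)
        return state, loss
    return jax.lax.scan(body_fn, svi_state, None, length=100)

loss_chunks = []
for idx in perm_idx:                             # Python loop: the big slice happens in NumPy
    svi_state, losses = chunk_steps(svi_state, jnp.asarray(samps_np[idx]))
    loss_chunks.append(losses)
losses = jnp.concatenate(loss_chunks)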