Extremely slow subsample

Hi numpyro-ers! Could somebody please help me understand why my SVI is running 40x slower when I use subsampling?

Below is a minimal example, inferring the mean (z) from a set of 10000 samples (y). Without subsampling, this code takes 0.8 seconds to run on my machine:

import jax
from jax import lax, numpy as jnp
import numpy as np
import numpyro
from numpyro import distributions as dist
from numpyro.infer import Trace_ELBO, SVI
import time

y = np.random.randn(10000)

def model(y):
    z = numpyro.sample("z", dist.Normal(0, 1))

    with numpyro.plate("data", len(y)):
        numpyro.sample("obs", dist.Normal(z, 1), obs=y)

def guide(y):
    numpyro.sample("z",
        dist.Normal(
            numpyro.param("mu_z", 0.),
            numpyro.param("sigma_z", 1., constraint=dist.constraints.positive)
        )
    )

#### Inference
rng_key = jax.random.PRNGKey(0)
optimizer = numpyro.optim.Adam(step_size=1e-4)
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
init_state = svi.init(rng_key, y)

t = time.time()
state = lax.fori_loop(0, 5000, lambda i, state: svi.update(state, y)[0], init_state)
params = svi.get_params(state)
print(f"mu_z = {params['mu_z']:.4f}")
print(f"Took {time.time()-t:.4f} seconds")

# Took 0.8473 seconds

But when I change the model to include subsampling, it now takes 32 seconds to run:

def model(y):
    z = numpyro.sample("z", dist.Normal(0, 1))

    with numpyro.plate("data", len(y), subsample_size=500):
        y = numpyro.subsample(y, event_dim=0)
        numpyro.sample("obs", dist.Normal(z, 1), obs=y)

# Took 31.8453 seconds

What’s happening here?
Big thanks! ~ Luke

@lbh Interesting! The reason is that the cost of drawing the subsample indices dominates the cost of computing the ELBO. Timing the index draw on its own:

import jax

def f(i):
    return jax.random.permutation(jax.random.PRNGKey(i), 10000)[:500]

jf = jax.jit(f)
%timeit x = jf(0).copy()

returns

5.38 ms ± 300 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

on my system. This is very slow compared to PyTorch:

%%timeit
import torch

x = torch.randperm(10000)[:500]

which returns 79.4 µs ± 418 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

Could you submit this issue upstream to JAX devs and ping me there?

@lbh Subsampling should be fast after this fix. :) Thanks for raising the issue!
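
For anyone reading later: the benchmark above shuffles all 10000 indices just to keep 500 of them. One way around that is a partial Fisher-Yates shuffle that stops after `subsample_size` swaps. The sketch below is only an illustration of that idea, not a claim about what the fix actually changed; the function name `partial_shuffle_subsample` is made up for this example.

```python
import jax.numpy as jnp
from jax import lax, random


def partial_shuffle_subsample(rng_key, size, subsample_size):
    """Draw `subsample_size` distinct indices from range(size).

    Hypothetical helper: runs only `subsample_size` steps of a
    Fisher-Yates shuffle instead of permuting all `size` elements.
    """
    keys = random.split(rng_key, subsample_size)

    def body_fn(perm, xs):
        key, step = xs
        last = size - 1 - step  # right edge of the not-yet-chosen pool
        j = random.randint(key, (), 0, last + 1)  # uniform in [0, last]
        # Swap perm[j] and perm[last]; the chosen index now sits at `last`.
        pair = jnp.stack([last, j])
        perm = perm.at[pair].set(perm[pair[::-1]])
        return perm, None

    steps = jnp.arange(subsample_size)
    perm, _ = lax.scan(body_fn, jnp.arange(size), (keys, steps))
    # The last `subsample_size` slots hold the chosen indices.
    return perm[size - subsample_size:]


idx = partial_shuffle_subsample(random.PRNGKey(0), 10000, 500)
```

To time it the same way as above, wrap it with `jax.jit(partial_shuffle_subsample, static_argnums=(1, 2))`. Whether it beats `jax.random.permutation` will depend on the backend, so measure on your own machine.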