# Extremely slow subsample

Hi numpyro-ers! Could somebody please help me understand why my SVI is running 40x slower when I use subsampling?

Below is a minimal example, inferring the mean (`z`) from a set of 10000 samples (`y`). Without subsampling, this code takes 0.8 seconds to run on my machine:

```python
import jax
from jax import lax, numpy as jnp
import numpy as np
import numpyro
from numpyro import distributions as dist
from numpyro.infer import Trace_ELBO, SVI
from numpyro.optim import Adam
import time

y = np.random.randn(10000)

def model(y):
    z = numpyro.sample("z", dist.Normal(0, 1))

    with numpyro.plate("data", len(y)):
        numpyro.sample("obs", dist.Normal(z, 1), obs=y)

def guide(y):
    numpyro.sample("z",
        dist.Normal(
            numpyro.param("mu_z", 0.),
            # keyword is `constraint` (a misspelled keyword is silently ignored)
            numpyro.param("sigma_z", 1., constraint=dist.constraints.positive)
        )
    )

#### Inference
# The optimizer is not shown in the post; Adam with this step size is an assumption.
optimizer = Adam(step_size=0.005)
rng_key = jax.random.PRNGKey(0)
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
init_state = svi.init(rng_key, y)

t = time.time()
# svi.update returns (state, loss); the loop carry must be the state only.
state = lax.fori_loop(0, 5000, lambda i, state: svi.update(state, y)[0], init_state)
params = svi.get_params(state)
print(f"mu_z = {params['mu_z']:.4f}")
print(f"Took {time.time()-t:.4f} seconds")

# Took 0.8473 seconds
```

But when I change the model to use subsampling, it now takes 32 seconds to run:

```python
def model(y):
    z = numpyro.sample("z", dist.Normal(0, 1))

    with numpyro.plate("data", len(y), subsample_size=500):
        y = numpyro.subsample(y, event_dim=0)
        numpyro.sample("obs", dist.Normal(z, 1), obs=y)

# Took 31.8453 seconds
```

What’s happening here?
Big thanks! ~ Luke

@lbh Interesting! The reason is that the cost of drawing a subsample dominates the cost of computing the ELBO. Timing just the index draw:

```python
import jax

def f(i):
    return jax.random.permutation(jax.random.PRNGKey(i), 10000)[:500]

jf = jax.jit(f)
%timeit x = jf(0).copy()
```

returns

```
5.38 ms ± 300 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
```

on my system. This is very slow compared to PyTorch:

```python
%%timeit
import torch

x = torch.randperm(10000)[:500]
```

which returns `79.4 µs ± 418 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)`
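To repeat the JAX measurement outside IPython, a rough equivalent of the `%timeit` cell above can use the standard-library `timeit` module. This sketch calls `block_until_ready()` so each call waits for JAX's asynchronous dispatch to finish, and compiles once before timing; the name `draw_indices` and the call count are illustrative choices, and the numbers will vary by machine and JAX version.

```python
import timeit

import jax

N, B = 10000, 500

@jax.jit
def draw_indices(key):
    # Shuffle all N indices, then keep the first B (the approach timed above).
    return jax.random.permutation(key, N)[:B]

key = jax.random.PRNGKey(0)
draw_indices(key).block_until_ready()  # trigger compilation outside the timed region

# block_until_ready() makes each call wait for the result to actually be computed.
n_calls = 100
seconds = timeit.timeit(lambda: draw_indices(key).block_until_ready(), number=n_calls)
print(f"{seconds / n_calls * 1e3:.3f} ms per call")
```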

Could you submit this issue upstream to the JAX developers and ping me there?


@lbh `subsample` should be fast after this fix. Thanks for raising the issue!
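The fix itself is not reproduced here. For context, one standard way to draw `k` of `n` indices without replacement, without shuffling all `n`, is a partial Fisher-Yates shuffle, which performs only `k` swaps. A minimal sketch in JAX, assuming `n` and `k` are static Python integers; `partial_fisher_yates` is a hypothetical helper, not a JAX or NumPyro API, and not necessarily what the fix does.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax

@partial(jax.jit, static_argnums=(1, 2))
def partial_fisher_yates(key, n, k):
    """Draw k distinct indices from range(n) using k swaps (hypothetical helper)."""
    keys = jax.random.split(key, k)

    def swap_step(perm, step):
        t, subkey = step
        i = n - 1 - t  # position being finalized, counting down from the end
        j = jax.random.randint(subkey, (), 0, i + 1)  # uniform over [0, i]
        pi, pj = perm[i], perm[j]
        return perm.at[i].set(pj).at[j].set(pi), None

    perm, _ = lax.scan(swap_step, jnp.arange(n), (jnp.arange(k), keys))
    # The last k slots now hold a uniformly random k-subset of range(n).
    return perm[n - k:]

idx = partial_fisher_yates(jax.random.PRNGKey(0), 10000, 500)
print(idx.shape)  # (500,)
```

Because the `k` swaps run one after another inside `lax.scan`, this trades parallelism for doing less work than a full shuffle, which is likely to pay off mainly on CPU.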