Hi numpyro-ers! Could somebody please help me understand why my SVI is running 40x slower when I use subsampling?
Below is a minimal example that infers the mean (`z`) from a set of 10,000 samples (`y`). Without subsampling, this code takes 0.8 seconds to run on my machine:
import jax
from jax import lax, numpy as jnp
import numpy as np
import numpyro
from numpyro import distributions as dist
from numpyro.infer import Trace_ELBO, SVI
import time
# 10,000 draws from a standard normal, so the true mean of the data is 0.
y = np.random.standard_normal(size=10000)
def model(y):
    """Infer an unknown mean ``z`` from the observations ``y``.

    ``z`` has a Normal(0, 1) prior, and every entry of ``y`` is observed
    under Normal(z, 1) inside a plate spanning all of ``y``.
    """
    z = numpyro.sample("z", dist.Normal(0, 1))
    n_obs = len(y)
    with numpyro.plate("data", n_obs):
        numpyro.sample("obs", dist.Normal(z, 1), obs=y)
def guide(y):
    """Normal variational family for the latent mean ``z``.

    Learns a location ``mu_z`` and a scale ``sigma_z`` for ``z``.
    ``y`` is not used here; it is accepted only so the guide has the same
    call signature as the model it is paired with in SVI.
    """
    mu_z = numpyro.param("mu_z", 0.0)
    # The keyword is ``constraint`` (singular). The earlier misspelling
    # ``constaints`` was not a recognised option and was silently ignored,
    # which left ``sigma_z`` unconstrained -- the optimiser could then push
    # the Normal scale to zero or below.
    sigma_z = numpyro.param(
        "sigma_z", 1.0, constraint=dist.constraints.positive
    )
    numpyro.sample("z", dist.Normal(mu_z, sigma_z))
#### Inference
rng_key = jax.random.PRNGKey(0)
optimizer = numpyro.optim.Adam(step_size=1e-4)
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
init_state = svi.init(rng_key, y)


def _svi_step(_, state):
    """Run one SVI update on the full data set and drop the per-step loss."""
    new_state, _loss = svi.update(state, y)
    return new_state


# NOTE: the timed region includes tracing/compiling the fori_loop body,
# not only the 5000 optimisation steps themselves.
t = time.perf_counter()  # monotonic clock intended for measuring intervals
state = lax.fori_loop(0, 5000, _svi_step, init_state)
params = svi.get_params(state)
# JAX dispatches work asynchronously; wait for the result explicitly so the
# elapsed time covers the computation instead of relying on print() to force it.
params["mu_z"].block_until_ready()
elapsed = time.perf_counter() - t
print(f"mu_z = {params['mu_z']:.4f}")
print(f"Took {elapsed:.4f} seconds")
# Took 0.8473 seconds
But when I change the model to include subsampling, it now takes 32 seconds to run:
def model(y):
    """Infer the mean ``z`` while scoring only a random minibatch of ``y``.

    The ``data`` plate is declared with the full size ``len(y)`` and a
    ``subsample_size`` of 500; ``numpyro.subsample`` picks out the entries of
    ``y`` matching the plate's current subsample before they are observed.
    """
    z = numpyro.sample("z", dist.Normal(0, 1))
    n_total = len(y)
    # NOTE(review): the reported slowdown presumably comes from how the plate
    # draws its 500 indices out of len(y) on every step, which is internal to
    # NumPyro -- confirm with a profiler before restructuring the model.
    with numpyro.plate("data", n_total, subsample_size=500):
        # event_dim=0: each element of y is a single scalar observation.
        y_batch = numpyro.subsample(y, event_dim=0)
        numpyro.sample("obs", dist.Normal(z, 1), obs=y_batch)


# Took 31.8453 seconds
What’s happening here?
Big thanks! ~ Luke