Hi all,
I’m currently fitting a model on a single GPU and I’m running into an OOM error.
All the chains are kept in GPU memory, which exceeds the GPU’s 16 GB.
To work around this, I implemented batching, i.e. running the chains one at a time and moving each finished chain to the CPU (Forum 1, Forum 2).
It now runs with 2000 samples. However, BFMI is low, so I need to increase the number of samples (see the minimal check below).
With 4000 samples the OOM error is triggered again. It seems that two chains are always kept in memory (one chain is ~6 GB, two chains ~12 GB → OOM on the 16 GB GPU).
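(A minimal BFMI check with ArviZ, on the `idata` assembled by the batching code further down, looks roughly like this:)

```python
import arviz as az

# E-BFMI per chain; values well below ~0.2-0.3 point to problematic energy transitions.
print(az.bfmi(idata))
az.plot_energy(idata)  # visual comparison of marginal vs. transition energy distributions
```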
Here are the experiments I have run so far:
| XLA_PYTHON_CLIENT_PREALLOCATE | XLA_PYTHON_CLIENT_ALLOCATOR | num_samples | num_warmup | Chain method | VRAM at start of each batch | Ran successfully? |
|---|---|---|---|---|---|---|
| False | Platform | 4000 | 10000 | sequential | 1) 744 MB, 2) 6566 MB | No |
| False | Platform | 2000 | 10000 | sequential | 1) 744 MB, 2) 3695 MB, 3) 3695 MB, 4) 3695 MB | Yes |
| Default | Default | 2000 | 10000 | sequential | 1) 13,385 MB, 2) 13,385 MB | No |
As others have reported in the forums, the 2nd batch is usually what triggers the OOM error. I have mainly tried setting these two flags, XLA_PYTHON_CLIENT_PREALLOCATE and XLA_PYTHON_CLIENT_ALLOCATOR, as suggested in the forums linked above (set as sketched below).
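For completeness, a minimal sketch of how the two flags are set (they only take effect if set before JAX is first imported):

```python
import os

# Must be set before the first `import jax`, otherwise the flags are ignored.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"   # don't preallocate most of the VRAM up front
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"  # allocate/free on demand (slower, but releases memory)

import jax  # noqa: E402
```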
While sampling the 2nd batch, not all of the GPU memory from the previous batch is released. I assume this is the previous sampling state; it does not get released until the model has run again.
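For reference, the per-batch VRAM numbers in the table can be reproduced by querying the device right before each `mcmc.run`, e.g. with the sketch below (`Device.memory_stats()` is GPU-only and only available in reasonably recent jaxlib versions; `nvidia-smi` reports the same numbers):

```python
import jax

def log_vram(tag: str) -> None:
    """Print current and peak bytes in use on the first accelerator (GPU backends only)."""
    stats = jax.devices()[0].memory_stats()  # e.g. bytes_in_use, peak_bytes_in_use, bytes_limit
    print(
        f"{tag}: {stats['bytes_in_use'] / 1e6:,.0f} MB in use, "
        f"peak {stats['peak_bytes_in_use'] / 1e6:,.0f} MB"
    )

# e.g. log_vram(f"before batch {i}") right before each mcmc.run(...)
```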
Current implementation of the model and the batching:
Model definition and NUTS+MCMC configuration
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro import plate
from numpyro.infer import MCMC, NUTS

def model(n_items, n_factors, n_persons, responses_mask, responses=None):
    # Item difficulties
    with plate("diff_dim1", n_items, dim=-1):
        diff = numpyro.sample("diff", dist.Normal(loc=0.0, scale=1.0))
    # Item discriminations (item_factors, the Q-matrix, comes from the enclosing scope)
    with plate("discrim_dim1", n_factors, dim=-1):
        with plate("discrim_dim2", n_items, dim=-2):
            discrim_offset = numpyro.sample(
                "discrim_offset", dist.LogNormal(loc=0.0, scale=1.0)
            )
            # Need to use a deterministic site to apply the Q-matrix
            discrim = numpyro.deterministic(
                "discrim", discrim_offset * item_factors
            )
    # Person abilities
    with plate("ability_dim1", n_persons, dim=-1):
        corr = numpyro.sample("ability_corr", dist.LKJ(n_factors, jnp.ones([1])))
        # Mean is always 0
        mu = jnp.zeros(n_factors)
        ability = numpyro.sample(
            "ability", dist.MultivariateNormal(loc=mu, covariance_matrix=corr)
        )
    # Logistic regression likelihood, masked for missing responses
    kernel = jnp.dot(ability, discrim.T) + diff
    numpyro.sample(
        "obs",
        dist.BernoulliLogits(logits=kernel).mask(responses_mask),
        obs=responses,
    )

# Defining the sampler
kernel = NUTS(
    model, adapt_step_size=True, adapt_mass_matrix=True, dense_mass=False
)
mcmc = MCMC(
    kernel,
    num_warmup=10000,
    num_samples=2000,
    num_chains=1,
    progress_bar=True,
    chain_method="sequential",
    jit_model_args=True,
)
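As a sanity check on the ~6 GB-per-chain figure, almost all of the stored trace is the per-draw latent state, so a rough back-of-the-envelope estimate looks like the sketch below. The n_persons / n_items / n_factors values are placeholders (my real sizes are not shown here), chosen only so the result lands near the observed number; note that `discrim` is a deterministic site and `ability_corr` sits inside the persons plate, so both are stored for every draw as well.

```python
# Back-of-the-envelope size of one chain's stored samples, float32 (4 bytes/value).
# The three sizes below are placeholders, not my real data dimensions.
n_persons, n_items, n_factors = 30_000, 60, 3
num_samples = 4000

floats_per_draw = (
    n_items                      # diff
    + n_items * n_factors        # discrim_offset
    + n_items * n_factors        # discrim (deterministic site, also stored)
    + n_persons * n_factors      # ability
    + n_persons * n_factors**2   # ability_corr (one matrix per person under the plate)
)
print(f"~{num_samples * floats_per_draw * 4 / 1e9:.1f} GB per chain")  # ~5.8 GB with these sizes
```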
Batching function
import gc

import arviz as az
import jax
import numpy as np
from jax import random

# `chains` and the self.* data attributes come from the enclosing class/method.
_traces = []
_extra_fields = []
for i in range(chains):
    mcmc.run(
        random.PRNGKey(0),
        self.responses_mask,
        self.item_factors,
        self.response_trains,
        extra_fields=(
            "num_steps",
            "potential_energy",
            "energy",
            "adapt_state.step_size",
            "accept_prob",
            "diverging",
        ),
    )
    # Transfer samples and diagnostics to the CPU
    samples = jax.device_put(
        mcmc.get_samples(group_by_chain=True), jax.devices("cpu")[0]
    )
    extra_fields = jax.device_put(
        mcmc.get_extra_fields(group_by_chain=True), jax.devices("cpu")[0]
    )
    _traces.append(samples)
    _extra_fields.append(extra_fields)
    del samples, extra_fields
    # Carry the last state over as the warmup state for the next run
    # (public equivalent in recent NumPyro: mcmc.post_warmup_state = mcmc.last_state)
    mcmc._warmup_state = mcmc._last_state
    gc.collect()

# Prepare the traces for arviz
trace = {}
extras = {}
for k in _traces[0]:
    trace[k] = np.concatenate([t[k] for t in _traces])
for k in _extra_fields[0]:
    extras[k] = np.concatenate([e[k] for e in _extra_fields])

idata = az.convert_to_inference_data(trace)
iextra = az.convert_to_inference_data(extras, group="sample_stats")
az.concat(idata, iextra, inplace=True)
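One variant I am considering (untested) is to also park the carried-over sampler state on the host between batches, so only the active run keeps large buffers on the GPU; a sketch using the public last_state / post_warmup_state attributes that recent NumPyro versions expose:

```python
# Untested sketch: move the carried-over state to the CPU after each batch and
# bring it back right before the next run, instead of keeping it on the GPU.
cpu = jax.devices("cpu")[0]
gpu = jax.devices("gpu")[0]

last_state_host = jax.device_put(mcmc.last_state, cpu)  # park the state on the host
gc.collect()

# ... immediately before the next mcmc.run(...):
mcmc.post_warmup_state = jax.device_put(last_state_host, gpu)
```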
Questions:
- Is there any way to reduce the memory consumption and avoid the OOM issue?
- Are there any suggestions for improving the situation, or any mistakes you can spot in the implementation?
- Why do there seem to be 2 chains in memory at all times, when I only keep the last state (which should be much smaller)?
- Why do I run out of memory before hitting 16 GB? Is it fragmentation?
Thank you!