Okay, so I am still getting issues with GPU memory usage.
I saw that in this thread you recommended using SVI and moving data between the CPU and GPU in batches, so that is what I am trying here:
import numpyro
from numpyro.infer import SVI, Trace_ELBO, autoguide
from jax import random
from tqdm import tqdm

# Low-rank multivariate normal guide over the model's latent sites
guide = autoguide.AutoLowRankMultivariateNormal(logistic_model)
optimizer = numpyro.optim.Adam(step_size=0.0005)
svi = SVI(logistic_model, guide, optimizer, Trace_ELBO())

batch_size = 50_000  # 50k rows per batch
rng_key = random.PRNGKey(0)

# Initialise the SVI state on a first random batch
batch_data = generate_random_batch_data(prepared_data, batch_size=batch_size)
svi_state = svi.init(rng_key, **batch_data)

# One optimisation step per freshly sampled batch
for _ in tqdm(range(1000)):
    batch_data = generate_random_batch_data(prepared_data, batch_size=batch_size)
    svi_state, loss = svi.update(svi_state, **batch_data)

params = svi.get_params(svi_state)
param_quantiles = guide.quantiles(params, [0.025, 0.5, 0.975])
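For context, generate_random_batch_data is along these lines. This is a rough sketch rather than the exact code, and the dict keys ("features", "labels") are placeholders for whatever logistic_model actually takes; the idea is that prepared_data stays on the host as NumPy arrays and only the sampled batch is handed to JAX (and hence the GPU) each step:

import numpy as np

def generate_random_batch_data(prepared_data, batch_size):
    """Sample a random minibatch from host-resident NumPy arrays.

    prepared_data: dict of full-size NumPy arrays kept on the CPU.
    Only the returned batch gets converted to JAX device arrays
    (i.e. moved to the GPU) inside svi.init / svi.update.
    """
    n = prepared_data["features"].shape[0]
    idx = np.random.choice(n, size=batch_size, replace=False)
    # Key names here are placeholders for the model's actual arguments.
    return {
        "features": prepared_data["features"][idx],
        "labels": prepared_data["labels"][idx],
    }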
But I am still running into memory issues (traceback below). Is there anything you can recommend to help reduce GPU memory usage?
Traceback (most recent call last):
File "main.py", line 86, in <module>
svi_state = svi.init(
File "/root/.local/lib/python3.8/site-packages/numpyro/infer/svi.py", line 180, in init
guide_trace = trace(guide_init).get_trace(*args, **kwargs, **self.static_kwargs)
File "/root/.local/lib/python3.8/site-packages/numpyro/handlers.py", line 171, in get_trace
self(*args, **kwargs)
File "/root/.local/lib/python3.8/site-packages/numpyro/primitives.py", line 105, in __call__
return self.fn(*args, **kwargs)
File "/root/.local/lib/python3.8/site-packages/numpyro/primitives.py", line 105, in __call__
return self.fn(*args, **kwargs)
File "/root/.local/lib/python3.8/site-packages/numpyro/infer/autoguide.py", line 559, in __call__
latent = self._sample_latent(*args, **kwargs)
File "/root/.local/lib/python3.8/site-packages/numpyro/infer/autoguide.py", line 547, in _sample_latent
posterior = self._get_posterior()
File "/root/.local/lib/python3.8/site-packages/numpyro/infer/autoguide.py", line 1035, in _get_posterior
cov_factor = cov_factor * scale[..., None]
File "/root/.local/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 6784, in deferring_binary_op
return binary_op(self, other)
File "/root/.local/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 165, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/root/.local/lib/python3.8/site-packages/jax/_src/api.py", line 427, in cache_miss
out_flat = xla.xla_call(
File "/root/.local/lib/python3.8/site-packages/jax/core.py", line 1690, in bind
return call_bind(self, fun, *args, **params)
File "/root/.local/lib/python3.8/site-packages/jax/core.py", line 1702, in call_bind
outs = top_trace.process_call(primitive, fun, tracers, params)
File "/root/.local/lib/python3.8/site-packages/jax/core.py", line 601, in process_call
return primitive.impl(f, *tracers, **params)
File "/root/.local/lib/python3.8/site-packages/jax/_src/dispatch.py", line 145, in _xla_call_impl
out = compiled_fun(*args)
File "/root/.local/lib/python3.8/site-packages/jax/_src/dispatch.py", line 444, in _execute_compiled
out_bufs = compiled.execute(input_bufs)
jax._src.traceback_util.UnfilteredStackTrace: RuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 13042071472 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
parameter allocation: 12.15GiB
constant allocation: 0B
maybe_live_out allocation: 12.15GiB
preallocated temp allocation: 0B
total allocation: 24.30GiB
total fragmentation: 0B (0.00%)