Help Needed: Hierarchical Model with crossed structure

Okay, so I am still running into GPU memory issues.
I saw that in this thread you recommended using SVI and moving data between the CPU and GPU in batches.

So I am trying this:


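import numpyro
from numpyro.infer import SVI, Trace_ELBO, autoguide
from tqdm import tqdm

# logistic_model, prepared_data, generate_random_batch_data and rng_key
# are defined earlier in main.py.
guide = autoguide.AutoLowRankMultivariateNormal(logistic_model)
optimizer = numpyro.optim.Adam(0.0005)
svi = SVI(logistic_model, guide, optimizer, Trace_ELBO())
batch_size = 50_000  # 50k rows per mini-batch

# Initialize the SVI state from one random batch.
batch_data = generate_random_batch_data(prepared_data, batch_size=batch_size)
svi_state = svi.init(
    rng_key,
    **batch_data,
)

# Each step draws a fresh random batch and takes one optimizer update.
for _ in tqdm(range(1000)):
    batch_data = generate_random_batch_data(prepared_data, batch_size=batch_size)
    svi_state, loss = svi.update(svi_state, **batch_data)

# Read the fitted guide parameters and summarize the approximate posterior.
params = svi.get_params(svi_state)
param_quantiles = guide.quantiles(params, [0.025, 0.5, 0.975])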

But I am still getting memory issues: the out-of-memory error is raised inside svi.init, before the first update (traceback at the end of this post). Is there anything you can recommend to reduce GPU memory usage?
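
For context, generate_random_batch_data is not shown in the snippet above. A minimal sketch of what such a helper might look like, assuming prepared_data is a dict of equal-length arrays kept in host (CPU) memory; the body, the rng argument, and the toy data are hypothetical and not necessarily what main.py does:

```python
import numpy as np


def generate_random_batch_data(prepared_data, batch_size, rng=None):
    """Hypothetical helper: draw one random mini-batch from host-memory arrays.

    Assumes every value in prepared_data has the same leading dimension.
    """
    rng = np.random.default_rng() if rng is None else rng
    # Number of rows, taken from the first array in the dict.
    n_rows = len(next(iter(prepared_data.values())))
    # Sample row indices without replacement, capped at the dataset size.
    idx = rng.choice(n_rows, size=min(batch_size, n_rows), replace=False)
    # Apply the same indices to every array so rows stay aligned.
    return {name: np.asarray(values)[idx] for name, values in prepared_data.items()}


# Toy usage with made-up data: only the selected rows would be handed to SVI.
toy = {"x": np.arange(10.0), "y": np.arange(10) % 2}
batch = generate_random_batch_data(toy, batch_size=4, rng=np.random.default_rng(0))
```

With a helper along these lines, only batch_size rows at a time are passed to svi.init and svi.update, so the full dataset never needs to sit on the GPU.
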

Traceback (most recent call last):
  File "main.py", line 86, in <module>
    svi_state = svi.init(
  File "/root/.local/lib/python3.8/site-packages/numpyro/infer/svi.py", line 180, in init
    guide_trace = trace(guide_init).get_trace(*args, **kwargs, **self.static_kwargs)
  File "/root/.local/lib/python3.8/site-packages/numpyro/handlers.py", line 171, in get_trace
    self(*args, **kwargs)
  File "/root/.local/lib/python3.8/site-packages/numpyro/primitives.py", line 105, in __call__
    return self.fn(*args, **kwargs)
  File "/root/.local/lib/python3.8/site-packages/numpyro/primitives.py", line 105, in __call__
    return self.fn(*args, **kwargs)
  File "/root/.local/lib/python3.8/site-packages/numpyro/infer/autoguide.py", line 559, in __call__
    latent = self._sample_latent(*args, **kwargs)
  File "/root/.local/lib/python3.8/site-packages/numpyro/infer/autoguide.py", line 547, in _sample_latent
    posterior = self._get_posterior()
  File "/root/.local/lib/python3.8/site-packages/numpyro/infer/autoguide.py", line 1035, in _get_posterior
    cov_factor = cov_factor * scale[..., None]
  File "/root/.local/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 6784, in deferring_binary_op
    return binary_op(self, other)
  File "/root/.local/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 165, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/root/.local/lib/python3.8/site-packages/jax/_src/api.py", line 427, in cache_miss
    out_flat = xla.xla_call(
  File "/root/.local/lib/python3.8/site-packages/jax/core.py", line 1690, in bind
    return call_bind(self, fun, *args, **params)
  File "/root/.local/lib/python3.8/site-packages/jax/core.py", line 1702, in call_bind
    outs = top_trace.process_call(primitive, fun, tracers, params)
  File "/root/.local/lib/python3.8/site-packages/jax/core.py", line 601, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/root/.local/lib/python3.8/site-packages/jax/_src/dispatch.py", line 145, in _xla_call_impl
    out = compiled_fun(*args)
  File "/root/.local/lib/python3.8/site-packages/jax/_src/dispatch.py", line 444, in _execute_compiled
    out_bufs = compiled.execute(input_bufs)
jax._src.traceback_util.UnfilteredStackTrace: RuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 13042071472 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:   12.15GiB
              constant allocation:         0B
        maybe_live_out allocation:   12.15GiB
     preallocated temp allocation:         0B
                 total allocation:   24.30GiB
              total fragmentation:         0B (0.00%
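
A rough size check on the failing allocation, as a sketch under stated assumptions. The traceback points at the cov_factor multiplication in AutoLowRankMultivariateNormal, and this assumes cov_factor has shape (latent_dim, rank) in float32 and that rank defaults to round(sqrt(latent_dim)) when it is not passed. The function name and the latent_dim value are hypothetical: the latent size is simply one value consistent with the reported byte count, not something read from the model.

```python
import math


def low_rank_factor_bytes(latent_dim, rank=None, bytes_per_float=4):
    """Estimate the size of a (latent_dim, rank) covariance factor in bytes.

    Assumption: when rank is None, the guide uses round(sqrt(latent_dim)).
    """
    if rank is None:
        rank = int(round(math.sqrt(latent_dim)))
    return latent_dim * rank * bytes_per_float


# Byte count copied from the RESOURCE_EXHAUSTED message in the traceback.
reported_bytes = 13_042_071_472

# Hypothetical latent size that reproduces the reported figure under the
# assumptions above (roughly 2.2 million latent variables).
latent_dim = 2_198_596
default_rank_bytes = low_rank_factor_bytes(latent_dim)

# The same latent size with a small explicit rank, for comparison.
small_rank_bytes = low_rank_factor_bytes(latent_dim, rank=10)

print(default_rank_bytes == reported_bytes)
print(f"default rank: {default_rank_bytes / 2**30:.2f} GiB")
print(f"rank=10:      {small_rank_bytes / 2**20:.1f} MiB")
```

If these assumptions hold, the allocation scales with the number of latent variables and the guide's rank rather than with batch_size, which would explain why batching the data does not help at svi.init. Passing a smaller explicit rank to AutoLowRankMultivariateNormal, or trying a diagonal guide such as AutoNormal, may be worth testing.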