Help Needed: Hierarchical Model with crossed structure

Hello, I have been following the Bayesian hierarchical tutorial here and this forum question here while trying to solve my problem, but I cannot seem to figure out what I am doing wrong, and I figure someone here would probably be able to spot it.

For my example, let us assume that I am trying to model whether a student applies for a spot at a university. My data generating process looks like this.

import uuid
import numpy as np
import pandas as pd
from scipy.special import expit, logit


treatment_effect_magnitude = 0
n_students = 200
n_universities = 10
prob_treat = 0.5
intercept = -0.25

n_students = int(n_students)
n_universities = int(n_universities)

# create a list of students
student_ids = [str(uuid.uuid4()) for _ in range(n_students)]

# create a list of universities
uni_ids = [str(uuid.uuid4()) for _ in range(n_universities)]

# mu_i is the affinity of student i to apply for any uni
mu_i = np.random.normal(0, 0.2, n_students) 

# zeta_j is the affinity of uni j to receive an apply from any student
zeta_j = np.random.normal(0, 0.1, n_universities) 

# assign each student to a treatment arm
treatment = np.random.binomial(1, prob_treat, n_students)

data = []

for student_i, student_id in enumerate(student_ids):
    for uni_j, uni_id in enumerate(uni_ids):
        
        # generate linear predictor 
        eta = intercept + treatment_effect_magnitude * treatment[student_i] + mu_i[student_i] + zeta_j[uni_j]
        
        # generate the probability that the student will apply for uni j
        prob_apply = expit(eta)

        # generate outcome
        y = np.random.binomial(1, prob_apply, 1)[0]

        # store row of data for later model fitting
        data.append(dict(
            y = y,
            treatment = 'control' if treatment[student_i] == 0 else 'treatment',
            student_id = student_id,
            uni_id = uni_id,
        ))

data = pd.DataFrame(data)

I can then fit it using bambi, and it recovers roughly the correct answer, like so.

import bambi
import arviz as az
formula = 'y ~ 0 + treatment + (1|uni_id) + (1|student_id)'

model = bambi.Model(formula, data, family = 'bernoulli')

# build graph
model.build()
model.graph()

fitted_model = model.fit(
    method = 'mcmc',
    draws = 1000,
    tune  = 1000,
    cores = 4,
    chains = 4,
    include_mean = False,
    omit_offsets  = False
)

# cut down list of variables summary
az.summary(fitted_model, var_names = ['treatment','1|uni_id_sigma','1|student_id_sigma'])

# full list of variables 
az.summary(fitted_model)

but when I try to fit with numpyro I get very different results, and I think I am doing something wrong. I suspect I have messed something up when specifying the random component, or that I am encoding the variables incorrectly. If you could help, it would be appreciated. My model is below.

import numpyro
from numpyro.infer import MCMC, NUTS, Predictive
import numpyro.distributions as dist
from jax import random
from sklearn.preprocessing import LabelEncoder
assert numpyro.__version__.startswith("0.9.0")

# define model
def numpyro_model(student_id, uni_id, treatment_id, y_obs):
    sigma_student = numpyro.sample("sigma_student", dist.HalfNormal(5))
    sigma_uni = numpyro.sample("sigma_uni", dist.HalfNormal(5))

    unique_student_ids = np.unique(student_id)
    n_students = len(unique_student_ids)

    unique_uni_IDs = np.unique(uni_id)
    n_unis = len(unique_uni_IDs)

    unique_treatment_IDs = np.unique(treatment_id)
    n_treatment = len(unique_treatment_IDs)

    with numpyro.plate("plate_i", n_students):
        offset_student = numpyro.sample("offset_student", dist.Normal(0.0, 1.0))
        alpha_student_i = offset_student * sigma_student

    with numpyro.plate("plate_j", n_unis):
        offset_uni = numpyro.sample("offset_uni", dist.Normal(0.0, 1.0))
        beta_uni_j = offset_uni * sigma_uni

    with numpyro.plate("plate_f", n_treatment):
        treatment_k = numpyro.sample("treatment_k", dist.Normal(0, 5))

    # linear predictor on the logit scale: student effect + university effect + treatment effect
    eta = alpha_student_i[student_id] + beta_uni_j[uni_id] + treatment_k[treatment_id]

    with numpyro.plate("data", len(y_obs)):
        y_est = numpyro.sample("y_est",  dist.Bernoulli(logits=eta), obs=y_obs)


# run model 
le = LabelEncoder()
data["student_id_encoded"] = le.fit_transform(data["student_id"].values)
data["uni_id_encoded"] = le.fit_transform(data["uni_id"].values)
data["treatment_id_encoded"] = le.fit_transform(data["treatment"].values)

y_obs = data["y"].values
uni_id_encoded = data["uni_id_encoded"].values
student_id_encoded = data["student_id_encoded"].values
treatment_id_encoded = data["treatment_id_encoded"].values

nuts_kernel = NUTS(numpyro_model)

mcmc = MCMC(nuts_kernel, num_samples=3000, num_warmup=3000, num_chains = 4)
rng_key = random.PRNGKey(0)
mcmc.run(rng_key, student_id_encoded, uni_id_encoded, treatment_id_encoded,y_obs)

posterior_samples = mcmc.get_samples()

# summary of results.

import arviz as az
data_az = az.from_numpyro(mcmc)
az.summary(data_az)

So if you could help, that would be appreciated. I am also trying to figure out how to speed up the fitting process, which is why I am trying to get this working in numpyro instead of bambi (pymc3). It is okay at the example data size, but with my real data set it is extremely slow, so any tips for speeding things up would also be appreciated.

Okay, I found the issue. It was a silly mistake: I multiplied where I should have added.

As a side note, I cannot for the life of me get this to run on multiple CPUs at once. My computer has 8 CPUs, and I have tried setting numpyro.set_host_device_count(4), but it just ignores the command and runs on one CPU. I thought it might be because I am running under WSL and it is getting confused, but when I put the code into a Docker image and run it on AWS, it has the same issue.

I feel like I must be missing something obvious here.

Did you call that statement at the beginning of your program?
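In case it helps, a minimal sketch of what "at the beginning" means (assuming a 4-chain run; the call has to happen before JAX initialises its devices, i.e. before any other JAX/NumPyro work is done):

import numpyro

# must be called before anything touches JAX, otherwise the device count is already fixed
numpyro.set_host_device_count(4)

from jax import random
from numpyro.infer import MCMC, NUTS

# ... define numpyro_model and the encoded data arrays as in the original post ...
# mcmc = MCMC(NUTS(numpyro_model), num_warmup=1000, num_samples=1000, num_chains=4)
# mcmc.run(random.PRNGKey(0), student_id_encoded, uni_id_encoded, treatment_id_encoded, y_obs)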

Ok cool, doing it at the beginning worked.
A couple of follow-up questions:

  1. If I set numpyro.set_host_device_count(), do you normally just set it to the number of chains that you have, or does it depend on the chain method selected? For example, if the chain method is parallel, I suspect setting set_host_device_count higher than the number of chains won't do anything. How about if I set the chain method to vectorized and set numpyro.set_host_device_count() larger than the number of chains: will that utilise the excess CPUs? Or do I have to do something like this method (shown below) to utilise the extra CPUs?
import jax
import numpy as np
from jax import pmap
from jax.random import PRNGKey
from numpyro.infer import MCMC, NUTS

# `model` and `rng_seed` are assumed to be defined elsewhere
def do_mcmc(rng_key, n_vectorized=8):
    nuts_kernel = NUTS(model)
    mcmc = MCMC(
        nuts_kernel,
        num_warmup=1000,    # warmup/sample counts added so the snippet runs; pick your own values
        num_samples=1000,
        progress_bar=False,
        num_chains=n_vectorized,
        chain_method='vectorized'
    )
    mcmc.run(
        rng_key,
        extra_fields=("potential_energy",),
    )
    return {**mcmc.get_samples(), **mcmc.get_extra_fields()}

# Number of devices to pmap over
n_parallel = jax.local_device_count()
rng_keys = jax.random.split(PRNGKey(rng_seed), n_parallel)
traces = pmap(do_mcmc)(rng_keys)
# concatenate traces along pmap'ed axis
trace = {k: np.concatenate(v) for k, v in traces.items()}

Is there any way to utilise both the GPU and CPU at the same time?

It is independent of the number of devices in use. It uses the same computation resources as if you called jnp.sum(some_vector_arrays) (i.e. it uses 1 device).

Is there any way to utilise both the GPU and CPU at the same time?

It depends. You can use JAX’s host_callback for some CPU calls in a GPU program. Typically you can’t run 1 chain on GPU and another chain on CPU on a single process.
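In case an illustration helps, here is a minimal sketch of the host_callback idea (cpu_log1p is a hypothetical stand-in for a CPU-only computation; the enclosing jitted function can be running on the GPU while the callback executes on the host):

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import host_callback as hcb

# hypothetical CPU-only helper; it receives a NumPy array on the host
def cpu_log1p(x):
    return np.log1p(x)

@jax.jit
def f(x):
    # hand x back to the host, run cpu_log1p there, then return the result to the device
    y = hcb.call(cpu_log1p, x, result_shape=jax.ShapeDtypeStruct(x.shape, x.dtype))
    return jnp.sum(y)

print(f(jnp.arange(4.0)))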

Okay, just to confirm that I understand correctly: are you saying that the vectorised method already uses all the CPUs inside the device, so trying to split it up further would be pointless and would just add extra overhead?

Okay, so utilising the GPU and CPU for different chains is not going to work.

As a further question, is there much to be gained by trying to parallelise or vectorise MCMC algorithms in general? Does the sequential nature of the Markov chain limit the benefit here?

If your model is small, I think using the parallel method on CPU is best. If your model requires heavy computation, you can use the GPU. With a GPU, you can also use the vectorized method with a large number of chains (to take advantage of the GPU). A drawback of the vectorized method is that if the chains are not mixing well, things might be slow, because at each MCMC step we need to wait for all of the chains to finish before moving on to the next step.
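A minimal sketch of how those options are selected (this assumes the same numpyro_model and encoded data arrays from the original post):

from jax import random
from numpyro.infer import MCMC, NUTS

kernel = NUTS(numpyro_model)

# small model, several CPU cores available: one device per chain
# (remember to call numpyro.set_host_device_count at the top of the script)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000, num_chains=4, chain_method='parallel')

# heavy model on a single GPU: run many chains together as one batched computation
# mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000, num_chains=20, chain_method='vectorized')

mcmc.run(random.PRNGKey(0), student_id_encoded, uni_id_encoded, treatment_id_encoded, y_obs)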

the vectorised method already uses all the cpus inside the device

I’m not sure. The vectorized method behaves like jax.vmap(fun), where fun is 1 MCMC step. I’m not sure whether it will use all of your CPUs; it might depend on your JAX configs, flags, …
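To illustrate the analogy (a toy step function standing in for one MCMC transition; this is only a conceptual sketch, not how NumPyro is implemented internally):

import jax
import jax.numpy as jnp

# toy stand-in for one MCMC transition applied to a single chain state
def step(state):
    return state + 1.0

# chain_method='vectorized' is conceptually like vmap-ing the step across chains,
# so every chain advances in lockstep on one device
states = jnp.zeros(4)            # one scalar state per chain
states = jax.vmap(step)(states)  # all 4 chains advance together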

Great, thank you.
Another question, though.

The number of observations that I have is very large. When I try to set the chain method to vectorised while running on the GPU, I get the following error.

I assume I am running out of memory on the GPU. Is there anything I can do to work around this, apart from getting a GPU with more memory?

  0%|          | 0/2000 [00:00<?, ?it/s]
warmup:   0%|          | 1/2000 [00:08<4:46:23,  8.60s/it]2022-03-07 08:32:10.497444: W external/org_tensorflow/tensorflow/core/common_runtime/bfc_allocator.cc:479] Allocator (GPU_0_bfc) ran out of memory trying to allocate 5.17GiB (rounded to 5553072128)requested by op 
2022-03-07 08:32:10.497852: W external/org_tensorflow/tensorflow/core/common_runtime/bfc_allocator.cc:491] ********************************************************************************____________________
2022-03-07 08:32:10.498860: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2124] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 5553072000 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:    5.21GiB
              constant allocation:   48.51MiB
        maybe_live_out allocation:    5.21GiB
     preallocated temp allocation:  355.35MiB
  preallocated temp fragmentation:    11.4KiB (0.00%)
                 total allocation:   10.81GiB
              total fragmentation:  403.86MiB (3.65%)
Peak buffers:
	Buffer 1:
		Size: 5.17GiB
		Entry Parameter Subshape: f32[1000,1388268]
		==========================

	Buffer 2:
		Size: 5.17GiB
		Operator: op_name="jit(_body_fn)/jit(main)/scatter[update_consts=() dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0,), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.PROMISE_IN_BOUNDS]" source_file="/root/.local/lib/python3.8/site-packages/numpyro/util.py" source_line=328
		XLA Label: fusion
		Shape: f32[1000,1388268]
		==========================

	Buffer 3:
		Size: 52.96MiB
		Operator: op_name="jit(_body_fn)/jit(main)/body/body/scatter[update_consts=() dimension_numbers=ScatterDimensionNumbers(update_window_dims=(1,), inserted_window_dims=(0, 1), scatter_dims_to_operand_dims=(0, 1)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.PROMISE_IN_BOUNDS]" source_file="/root/.local/lib/python3.8/site-packages/numpyro/infer/hmc_util.py" source_line=1042
		XLA Label: fusion
		Shape: f32[4,10,347066]
		==========================

	Buffer 4:
		Size: 52.96MiB
		Operator: op_name="jit(_body_fn)/jit(main)/body/body/scatter[update_consts=() dimension_numbers=ScatterDimensionNumbers(update_window_dims=(1,), inserted_window_dims=(0, 1), scatter_dims_to_operand_dims=(0, 1)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.PROMISE_IN_BOUNDS]" source_file="/root/.local/lib/python3.8/site-packages/numpyro/infer/hmc_util.py" source_line=1042
		XLA Label: fusion
		Shape: f32[4,10,347066]
		==========================

	Buffer 5:
		Size: 52.96MiB
		Operator: op_name="jit(_body_fn)/jit(main)/body/while[cond_nconsts=1 body_nconsts=9]" source_file="/root/.local/lib/python3.8/site-packages/numpyro/util.py" source_line=131
		XLA Label: fusion
		Shape: f32[4,10,347066]
		==========================

	Buffer 6:
		Size: 52.96MiB
		Operator: op_name="jit(_body_fn)/jit(main)/body/while[cond_nconsts=1 body_nconsts=9]" source_file="/root/.local/lib/python3.8/site-packages/numpyro/util.py" source_line=131
		XLA Label: fusion
		Shape: f32[4,10,347066]
		==========================

	Buffer 7:
		Size: 52.96MiB
		Operator: op_name="jit(_body_fn)/jit(main)/body/while[cond_nconsts=1 body_nconsts=9]" source_file="/root/.local/lib/python3.8/site-packages/numpyro/util.py" source_line=131
		XLA Label: fusion
		Shape: f32[4,10,347066]
		==========================

	Buffer 8:
		Size: 9.70MiB
		XLA Label: constant
		Shape: f32[2543309]
		==========================

	Buffer 9:
		Size: 9.70MiB
		XLA Label: constant
		Shape: f32[2543309]
		==========================

	Buffer 10:
		Size: 9.70MiB
		XLA Label: constant
		Shape: s32[2543309,1]
		==========================

	Buffer 11:
		Size: 9.70MiB
		XLA Label: constant
		Shape: s32[2543309,1]
		==========================

	Buffer 12:
		Size: 9.70MiB
		XLA Label: constant
		Shape: s32[2543309,1]
		==========================

	Buffer 13:
		Size: 5.29MiB
		Entry Parameter Subshape: f32[4,347066]
		==========================

	Buffer 14:
		Size: 5.29MiB
		Entry Parameter Subshape: f32[4,347066]
		==========================

	Buffer 15:
		Size: 5.29MiB
		Entry Parameter Subshape: f32[4,347066]
		==========================



warmup:   0%|          | 1/2000 [00:18<10:19:37, 18.60s/it]
Traceback (most recent call last):
  File "main.py", line 88, in <module>
    mcmc.run(
  File "/root/.local/lib/python3.8/site-packages/numpyro/infer/mcmc.py", line 608, in run
    states, last_state = partial_map_fn(map_args)
  File "/root/.local/lib/python3.8/site-packages/numpyro/infer/mcmc.py", line 410, in _single_chain_mcmc
    collect_vals = fori_collect(
  File "/root/.local/lib/python3.8/site-packages/numpyro/util.py", line 358, in fori_collect
    vals = jit(_body_fn)(i, vals)
ValueError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 5553072000 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:    5.21GiB
              constant allocation:   48.51MiB
        maybe_live_out allocation:    5.21GiB
     preallocated temp allocation:  355.35MiB
  preallocated temp fragmentation:    11.4KiB (0.00%)
                 total allocation:   10.81GiB
              total fragmentation:  403.86MiB (3.65%)
Peak buffers:
	Buffer 1:
		Size: 5.17GiB
		Entry Parameter Subshape: f32[1000,1388268]
		==========================

	Buffer 2:
		Size: 5.17GiB
		Operator: op_name="jit(_body_fn)/jit(main)/scatter[update_consts=() dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0,), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.PROMISE_IN_BOUNDS]" source_file="/root/.local/lib/python3.8/site-packages/numpyro/util.py" source_line=328
		XLA Label: fusion
		Shape: f32[1000,1388268]
		==========================

	Buffer 3:
		Size: 52.96MiB
		Operator: op_name="jit(_body_fn)/jit(main)/body/body/scatter[update_consts=() dimension_numbers=ScatterDimensionNumbers(update_window_dims=(1,), inserted_window_dims=(0, 1), scatter_dims_to_operand_dims=(0, 1)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.PROMISE_IN_BOUNDS]" source_file="/root/.local/lib/python3.8/site-packages/numpyro/infer/hmc_util.py" source_line=1042
		XLA Label: fusion
		Shape: f32[4,10,347066]
		==========================

	Buffer 4:
		Size: 52.96MiB
		Operator: op_name="jit(_body_fn)/jit(main)/body/body/scatter[update_consts=() dimension_numbers=ScatterDimensionNumbers(update_window_dims=(1,), inserted_window_dims=(0, 1), scatter_dims_to_operand_dims=(0, 1)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.PROMISE_IN_BOUNDS]" source_file="/root/.local/lib/python3.8/site-packages/numpyro/infer/hmc_util.py" source_line=1042
		XLA Label: fusion
		Shape: f32[4,10,347066]
		==========================

	Buffer 5:
		Size: 52.96MiB
		Operator: op_name="jit(_body_fn)/jit(main)/body/while[cond_nconsts=1 body_nconsts=9]" source_file="/root/.local/lib/python3.8/site-packages/numpyro/util.py" source_line=131
		XLA Label: fusion
		Shape: f32[4,10,347066]
		==========================

	Buffer 6:
		Size: 52.96MiB
		Operator: op_name="jit(_body_fn)/jit(main)/body/while[cond_nconsts=1 body_nconsts=9]" source_file="/root/.local/lib/python3.8/site-packages/numpyro/util.py" source_line=131
		XLA Label: fusion
		Shape: f32[4,10,347066]
		==========================

	Buffer 7:
		Size: 52.96MiB
		Operator: op_name="jit(_body_fn)/jit(main)/body/while[cond_nconsts=1 body_nconsts=9]" source_file="/root/.local/lib/python3.8/site-packages/numpyro/util.py" source_line=131
		XLA Label: fusion
		Shape: f32[4,10,347066]
		==========================

	Buffer 8:
		Size: 9.70MiB
		XLA Label: constant
		Shape: f32[2543309]
		==========================

	Buffer 9:
		Size: 9.70MiB
		XLA Label: constant
		Shape: f32[2543309]
		==========================

	Buffer 10:
		Size: 9.70MiB
		XLA Label: constant
		Shape: s32[2543309,1]
		==========================

	Buffer 11:
		Size: 9.70MiB
		XLA Label: constant
		Shape: s32[2543309,1]
		==========================

	Buffer 12:
		Size: 9.70MiB
		XLA Label: constant
		Shape: s32[2543309,1]
		==========================

	Buffer 13:
		Size: 5.29MiB
		Entry Parameter Subshape: f32[4,347066]
		==========================

	Buffer 14:
		Size: 5.29MiB
		Entry Parameter Subshape: f32[4,347066]
		==========================

	Buffer 15:
		Size: 5.29MiB
		Entry Parameter Subshape: f32[4,347066]
		==========================

So the above is using the vectorised approach, which is experimental, so it is understandable that it might not work.

But when I use the sequential approach, it works fine until it has finished the final chain and tries to gather the results back together (line 603 in mcmc.py):

  File "main.py", line 88, in <module>
    mcmc.run(
  File "/root/.local/lib/python3.8/site-packages/numpyro/infer/mcmc.py", line 603, in run
    states, last_state = _laxmap(partial_map_fn, map_args)
  File "/root/.local/lib/python3.8/site-packages/numpyro/infer/mcmc.py", line 164, in _laxmap
    return tree_multimap(lambda *args: jnp.stack(args), *ys)
  File "/root/.local/lib/python3.8/site-packages/jax/_src/tree_util.py", line 180, in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/root/.local/lib/python3.8/site-packages/jax/_src/tree_util.py", line 180, in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/root/.local/lib/python3.8/site-packages/numpyro/infer/mcmc.py", line 164, in <lambda>
    return tree_multimap(lambda *args: jnp.stack(args), *ys)
  File "/root/.local/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 3407, in stack
    return concatenate(new_arrays, axis=axis)
  File "/root/.local/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 3461, in concatenate
    arrays = [lax.concatenate(arrays[i:i+k], axis)
  File "/root/.local/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 3461, in <listcomp>
    arrays = [lax.concatenate(arrays[i:i+k], axis)
  File "/root/.local/lib/python3.8/site-packages/jax/_src/lax/lax.py", line 557, in concatenate
    return concatenate_p.bind(*operands, dimension=dimension)
  File "/root/.local/lib/python3.8/site-packages/jax/core.py", line 279, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/root/.local/lib/python3.8/site-packages/jax/core.py", line 282, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/root/.local/lib/python3.8/site-packages/jax/core.py", line 598, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/root/.local/lib/python3.8/site-packages/jax/_src/dispatch.py", line 94, in apply_primitive
    return compiled_fun(*args)
  File "/root/.local/lib/python3.8/site-packages/jax/_src/dispatch.py", line 114, in <lambda>
    return lambda *args, **kw: compiled(*args, **kw)[0]
  File "/root/.local/lib/python3.8/site-packages/jax/_src/dispatch.py", line 444, in _execute_compiled
    out_bufs = compiled.execute(input_bufs)
RuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 5283776000 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:    4.92GiB
              constant allocation:         0B
        maybe_live_out allocation:    4.92GiB
     preallocated temp allocation:         0B
                 total allocation:    9.84GiB
              total fragmentation:         0B (0.00%)
Peak buffers:
	Buffer 1:
		Size: 4.92GiB
		Operator: op_name="jit(concatenate)/jit(main)/concatenate[dimension=0]" source_file="/root/.local/lib/python3.8/site-packages/numpyro/infer/mcmc.py" source_line=164
		XLA Label: concatenate
		Shape: f32[4,1000,330236]
		==========================

	Buffer 2:
		Size: 1.23GiB
		Entry Parameter Subshape: f32[1,1000,330236]
		==========================

	Buffer 3:
		Size: 1.23GiB
		Entry Parameter Subshape: f32[1,1000,330236]
		==========================

	Buffer 4:
		Size: 1.23GiB
		Entry Parameter Subshape: f32[1,1000,330236]
		==========================

	Buffer 5:
		Size: 1.23GiB
		Entry Parameter Subshape: f32[1,1000,330236]
		==========================


2022-03-07 09:56:02.127620: W external/org_tensorflow/tensorflow/core/common_runtime/bfc_allocator.cc:479] Allocator (GPU_0_bfc) ran out of memory trying to allocate 4.92GiB (rounded to 5283776000) requested by op
2022-03-07 09:56:02.127842: W external/org_tensorflow/tensorflow/core/common_runtime/bfc_allocator.cc:491] ******____****************************************************************************______________
2022-03-07 09:56:02.127916: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2124] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 5283776000 bytes.

Is there any way that this step could be done in system RAM instead of GPU RAM, or is there a way to specify the model so that it needs less GPU RAM?

I think you can do

# run each chain separately and pull its samples back into host (CPU) memory
mcmc.run(first_key, ...)
first_chain = jax.device_get(mcmc.get_samples())
mcmc.run(second_key, ...)
second_chain = jax.device_get(mcmc.get_samples())
# stack the per-chain arrays on the host, so the GPU never holds all chains at once
samples = jax.tree_util.tree_multimap(lambda x, y: np.stack([x, y]), first_chain, second_chain)

Okay great, that makes a lot of sense.
I am also wondering if there is a way to not keep track of certain variables, in the same way that we don't keep track of the draws from the warmup phase. Something like a reverse of the deterministic function. My model keeps track of a lot of offset variables that I need in order to fit the model properly but don't really care about afterwards, and they are chewing up a lot of RAM.

Having a look at the docs I couldn't see anything, but I thought I would ask; otherwise I can do something like this:
first_chain = jax.device_get({k: v for k, v in mcmc.get_samples().items() if k in variables_i_care_about})

I think your approach is the way to go. We don’t have a filtering option for get_samples.


Okay, so I am still getting issues with GPU memory usage.
I saw that in this thread you recommended using SVI and moving data between the CPU and GPU in batches.

So I am trying this:


from tqdm import tqdm
import numpyro
from numpyro.infer import SVI, Trace_ELBO, autoguide

# logistic_model, prepared_data, generate_random_batch_data and rng_key
# are defined elsewhere in my script
guide = autoguide.AutoLowRankMultivariateNormal(logistic_model)
optimizer = numpyro.optim.Adam(0.0005)
svi = SVI(logistic_model, guide, optimizer, Trace_ELBO())

batch_size = 50000  # 50k rows per batch
batch_data = generate_random_batch_data(prepared_data, batch_size=batch_size)

svi_state = svi.init(rng_key, **batch_data)

for _ in tqdm(range(1000)):
    # draw a fresh random batch for every SVI step
    batch_data = generate_random_batch_data(prepared_data, batch_size=batch_size)
    svi_state, loss = svi.update(svi_state, **batch_data)

params = svi.get_params(svi_state)
param_quantiles = guide.quantiles(params, [0.025, 0.5, 0.975])

But I am still getting memory issues. Is there anything you can recommend to help reduce GPU memory usage?

Traceback (most recent call last):
  File "main.py", line 86, in <module>
    svi_state = svi.init(
  File "/root/.local/lib/python3.8/site-packages/numpyro/infer/svi.py", line 180, in init
    guide_trace = trace(guide_init).get_trace(*args, **kwargs, **self.static_kwargs)
  File "/root/.local/lib/python3.8/site-packages/numpyro/handlers.py", line 171, in get_trace
    self(*args, **kwargs)
  File "/root/.local/lib/python3.8/site-packages/numpyro/primitives.py", line 105, in __call__
    return self.fn(*args, **kwargs)
  File "/root/.local/lib/python3.8/site-packages/numpyro/primitives.py", line 105, in __call__
    return self.fn(*args, **kwargs)
  File "/root/.local/lib/python3.8/site-packages/numpyro/infer/autoguide.py", line 559, in __call__
    latent = self._sample_latent(*args, **kwargs)
  File "/root/.local/lib/python3.8/site-packages/numpyro/infer/autoguide.py", line 547, in _sample_latent
    posterior = self._get_posterior()
  File "/root/.local/lib/python3.8/site-packages/numpyro/infer/autoguide.py", line 1035, in _get_posterior
    cov_factor = cov_factor * scale[..., None]
  File "/root/.local/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 6784, in deferring_binary_op
    return binary_op(self, other)
  File "/root/.local/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 165, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/root/.local/lib/python3.8/site-packages/jax/_src/api.py", line 427, in cache_miss
    out_flat = xla.xla_call(
  File "/root/.local/lib/python3.8/site-packages/jax/core.py", line 1690, in bind
    return call_bind(self, fun, *args, **params)
  File "/root/.local/lib/python3.8/site-packages/jax/core.py", line 1702, in call_bind
    outs = top_trace.process_call(primitive, fun, tracers, params)
  File "/root/.local/lib/python3.8/site-packages/jax/core.py", line 601, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/root/.local/lib/python3.8/site-packages/jax/_src/dispatch.py", line 145, in _xla_call_impl
    out = compiled_fun(*args)
  File "/root/.local/lib/python3.8/site-packages/jax/_src/dispatch.py", line 444, in _execute_compiled
    out_bufs = compiled.execute(input_bufs)
jax._src.traceback_util.UnfilteredStackTrace: RuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 13042071472 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:   12.15GiB
              constant allocation:         0B
        maybe_live_out allocation:   12.15GiB
     preallocated temp allocation:         0B
                 total allocation:   24.30GiB
              total fragmentation:         0B (0.00%

It is also a bit weird that it uses 12-15 GB of RAM when running on the CPU, where it holds the entire dataset, but when running on the GPU the total allocation is 24 GB even though it only holds the batch dataset.
I think something must not be set up correctly.