Memory leak for hierarchical mixture model

Hi everyone!

I recently switched from PyMC3 to NumPyro to try the promised speedup for NUTS. I managed to translate my PyMC3 model to NumPyro (version 0.8.0), and the code below reproduces a toy example with the main idea of the model. The model is a hierarchical mixture model with two constant mean parameters and a third mean that depends on external data. I struggled with the shapes of the different means; I found a workaround, but it does not seem optimal at all to me.

My goal was the following: define two random variables, mu (a mean) and x_ref (a threshold); compute the quantity f(x_ref, x), which depends on the random threshold x_ref and on the given data x; and finally use the two means mu and mu + f(x_ref, x) in a mixture model. Note that mu is supposed to be constant across all data points, while mu + f(x_ref, x) changes for each data point. My workaround artificially expands the dimensions of mu to match the shape of f(x_ref, x), but this transformation seems a bit rough to me. I also tried to select the corresponding mean inside a plate statement, but I always ended up with tensor shape problems, even when indexing the tensor inside the plate. A sketch of the shape manipulation is shown just below.
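
To make the shape issue concrete, the workaround boils down to broadcasting each scalar mean to the shape of f(x_ref, x) and stacking the three component means into a (num_data, num_components) array of locations. Here is a standalone sketch with made-up values (it is not part of the model, only an illustration of the shapes; in the model below I do the same thing with jnp.ones and a transpose, which is equivalent):

import jax.numpy as jnp

# Toy values only, to illustrate the shapes involved.
mu_0, mu_1, a = 1.0, 5.0, 3.0
intermediate_quantity = jnp.array([0.0, 2.5, 4.0])  # plays the role of f(x_ref, x), shape (num_data,)

locs = jnp.stack(
    [
        jnp.broadcast_to(mu_0, intermediate_quantity.shape),  # constant mean, expanded
        jnp.broadcast_to(mu_1, intermediate_quantity.shape),  # constant mean, expanded
        mu_1 + a * intermediate_quantity,                     # per-data-point mean
    ],
    axis=-1,
)
print(locs.shape)  # (num_data, num_components) == (3, 3) here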

The model runs with NumPyro (around 2-3 times faster than with PyMC3!). However, when I run the inference in a Jupyter notebook or in a Celery worker (both CPU-only), RAM usage increases a lot (around 300 MB for the toy example), decreases a bit after the inference, but remains significantly higher than its level before the inference. Surprisingly, this problem does not occur when I run a script directly from bash. I tried the different settings I found for managing memory allocation on GPUs (for instance GPU memory allocation — JAX documentation), but none of them worked. Any advice or guidance to solve the problem, or to identify its source more precisely, would be greatly appreciated. Thank you!
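
For completeness, the allocator settings from that page that I tried look roughly like this (they have to be set before importing jax, and none of them changed anything in my case):

import os

# GPU allocator settings from the JAX documentation; they must be set
# before jax is imported. None of them removed the leak for me.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"   # disable upfront preallocation
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"  # allocate/free on demand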

Here is the code of the model and the inference:

import numpy as np
from scipy import stats

import jax.numpy as jnp
import jax.random as jrandom
from jax import jit

import numpyro
import numpyro.distributions as dist


def model(data=None, x_min=None, x_max=None):
    mu_0 = numpyro.sample(
        "mu_0", dist.TruncatedNormal(loc=1, scale=1, low=0))
    mu_1 = numpyro.sample(
        "mu_1", dist.TruncatedNormal(loc=5, scale=1, low=0))


    sigma_1 = numpyro.sample(
        "sigma_1", dist.TruncatedNormal(loc=3, scale=1, low=0))
    sigma_2 = numpyro.sample(
        "sigma_2", dist.TruncatedNormal(loc=5, scale=1, low=0))
    a = numpyro.sample("a", dist.TruncatedNormal(loc=3, scale=1, low=0))
    x_ref = numpyro.sample("x_ref", dist.TwoSidedTruncatedDistribution(
        dist.Normal(loc=20, scale=1), low=18, high=22))

    p = numpyro.sample("p", dist.Dirichlet(jnp.ones(3)))

    intermediate_quantity = f_jax(x_ref=x_ref, x_min=x_min, x_max=x_max)

    with numpyro.plate("data", data.size):
        numpyro.sample(
            "values",
            dist.MixtureSameFamily(
                dist.Categorical(probs=p),
                dist.TruncatedNormal(
                    loc=jnp.stack(
                        [
                            mu_0 * jnp.ones(intermediate_quantity.shape),
                            mu_1 * jnp.ones(intermediate_quantity.shape),
                            mu_1 + a * intermediate_quantity
                        ]
                    ).T,
                    scale=jnp.array(
                        [
                            1,
                            sigma_1,
                            sigma_2
                        ]
                    ),
                    low=0
                )
            ),
            obs=data
        )

rng_key = jrandom.PRNGKey(0)
rng_key, rng_key_ = jrandom.split(rng_key)

size = 1000
x_min = np.random.uniform(-5.0, 15.0, size=size)
x_max = x_min + np.abs(np.random.normal(10, 1, size=size))
data = generate_data(size, x_min, x_max)

kernel = numpyro.infer.NUTS(model)
mcmc = numpyro.infer.MCMC(
    kernel, num_warmup=2500, num_samples=2500, num_chains=4)
mcmc.run(rng_key_, x_min=x_min, x_max=x_max, data=data)
mcmc.print_summary()

with the intermediate functions:

@jit
def f_jax(x_ref, x_min, x_max):
    filter_sup = x_ref >= x_max
    filter_middle = (x_ref < x_max) & (x_ref >= x_min)
    result = jnp.where(filter_sup, x_ref - (x_max + x_min) / 2, 0)
    result = jnp.where(filter_middle, (x_ref - x_min) *
                       (1 + (x_ref - x_min) / (x_max - x_min)), result)
    return result


def f(x_ref, x_min, x_max):
    filter_sup = x_ref >= x_max
    filter_middle = (x_ref < x_max) & (x_ref >= x_min)
    result = np.where(filter_sup, x_ref - (x_max + x_min) / 2, 0)
    result = np.where(filter_middle, (x_ref - x_min) *
                      (1 + (x_ref - x_min) / (x_max - x_min)), result)
    return result


def generate_data(size, x_min, x_max):
    classes = np.arange(size) % 3

    x_ref = 20

    a = 3

    mus = np.array([1, 5]).reshape(-1, 1) * np.ones((1, size))
    mus = np.vstack(
        [mus, (mus[1, :] + a * f(x_ref, x_min, x_max)).reshape(1, -1)])
    sigmas = np.array([1, 3, 5]).reshape(-1, 1)
    values = stats.truncnorm.rvs(-mus / sigmas,
                                 np.inf, loc=mus, scale=sigmas).T
    return values[np.arange(len(classes)), classes]

Hi @yamakishi, JAX and NumPyro both have caching mechanisms, but they should behave the same in a notebook and on the command line. I’m not sure how to resolve the issue. Maybe this JAX tutorial on device memory profiling is helpful.
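
For reference, taking a snapshot of live device memory looks roughly like this (the resulting file can then be inspected with pprof, as described in that tutorial):

import jax
import jax.numpy as jnp

# Allocate something, then dump a snapshot of live device memory to a file.
x = jnp.ones((1000, 1000))
jax.profiler.save_device_memory_profile("memory.prof")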

Hi @fehiepsi,

Thank you for your answer! I completely forgot to follow up on this topic. I did not figure out how to solve my memory leak directly with the tutorial, but I found a workaround: I created a parent script that calls a child script which runs the model. This way the child process is closed at the end of its execution, which solved the memory leak. Maybe it will be useful for someone with the same issue.
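
The parent script is essentially just the following, where run_model.py is a placeholder name for the script that defines the model and calls mcmc.run; when the child process exits, all the memory it allocated is returned to the OS:

import subprocess

# "run_model.py" is a placeholder for the script that builds the model and
# runs the MCMC. Running it in a child process means its memory is freed
# when that process terminates.
subprocess.run(["python", "run_model.py"], check=True)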