Hi everyone!
I recently switched from PyMC3 to NumPyro to try the promised speedup for NUTS, and I managed to translate my PyMC3 model to NumPyro (version 0.8.0). The code below is a toy example that reproduces the main idea of the model. It is a hierarchical mixture model with two mean parameters and a third mean that depends on external data. I struggled with the shapes of the different means. I found a workaround, but it does not seem optimal at all to me.
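Written as formulas, read directly from the code below (with $f$ the function implemented by `f_jax`), observation $y_i$ with covariates $x_{\min,i}, x_{\max,i}$ is modelled as:

```latex
\begin{aligned}
y_i &\sim \sum_{k=0}^{2} p_k\, \mathcal{N}_{[0,\infty)}\!\left(m_{ik}, s_k\right),
      \qquad p \sim \operatorname{Dirichlet}(1, 1, 1),\\
(m_{i0}, m_{i1}, m_{i2}) &= \big(\mu_0,\ \mu_1,\ \mu_1 + a\, f(x_{\text{ref}}, x_{\min,i}, x_{\max,i})\big),\\
(s_0, s_1, s_2) &= (1,\ \sigma_1,\ \sigma_2),
\end{aligned}
```

where $\mathcal{N}_{[0,\infty)}$ is a normal distribution truncated to $[0, \infty)$, $\mu_0, \mu_1, \sigma_1, \sigma_2, a$ have truncated-normal priors, and $x_{\text{ref}} \sim \mathcal{N}(20, 1)$ truncated to $[18, 22]$. Only $m_{i2}$ varies with $i$, which is why the `loc` passed to the mixture has shape (N, 3).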
My goal was the following: define two random variables, mu (a mean) and x_ref (a threshold); then compute the quantity f(x_ref, x), which depends on the random threshold x_ref and the observed data x; finally, use the two means mu and mu + f(x_ref, x) in a mixture model. Note that mu is meant to be constant for all data points, whereas mu + f(x_ref, x) changes for each data point. My workaround artificially expands mu to the shape of f(x_ref, x), but this transformation seems a bit crude to me. I also tried selecting the corresponding mean inside a plate statement, but I always ran into tensor shape problems, even when indexing my tensor inside the plate.
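A possibly cleaner way to build the (N, 3) array of means is to let JAX broadcast the scalar means against the per-point mean instead of multiplying by `jnp.ones(...)`. This is a sketch in plain `jax.numpy`, assuming `mu_0`, `mu_1` and `a` are scalars and `intermediate_quantity` has shape (N,); `stacked_locs` is a hypothetical helper name, not part of NumPyro.

```python
import jax.numpy as jnp


def stacked_locs(mu_0, mu_1, a, intermediate_quantity):
    """Return an (N, 3) array of component means (hypothetical helper).

    The scalar means mu_0 and mu_1 are broadcast against the length-N
    data-dependent mean, so no jnp.ones(...) tiling is needed.
    """
    third = mu_1 + a * intermediate_quantity         # shape (N,)
    cols = jnp.broadcast_arrays(mu_0, mu_1, third)   # each column now has shape (N,)
    return jnp.stack(cols, axis=-1)                  # shape (N, 3), last axis = component
```

It should give the same array as the `jnp.stack([...]).T` construction in the model, so it could be passed as `loc=stacked_locs(mu_0, mu_1, a, intermediate_quantity)` to `dist.TruncatedNormal` with the rest of the model unchanged. Stacking on the last axis keeps the component dimension last, which is where `MixtureSameFamily` looks for it.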
The model runs with NumPyro, around 2-3 times faster than with PyMC3! However, when I run the inference in a Jupyter Notebook or in a Celery worker (both CPU only), RAM usage increases a lot (around 300 MB for the toy example), decreases a bit after the inference, but remains significantly higher than before the inference. Surprisingly, this does not happen when I run a script directly from bash. I tried several settings I found for controlling memory allocation on GPUs (for instance from the "GPU memory allocation" page of the JAX documentation), but none of them helped. Any advice on how to solve the problem, or how to pinpoint its source more precisely, would be greatly appreciated. Thank you!
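To narrow down where the memory goes, one option is to log the process's resident memory (RSS) around the inference and to run it twice in the same notebook or worker. If the second run adds roughly the same amount again, something may be accumulating per run; a one-time increase would instead point to state kept for the lifetime of the process (for example JAX's cache of compiled code), which a script run from bash releases when it exits. This is a minimal standard-library sketch, assuming Linux (it reads `VmRSS` from `/proc/self/status`); `current_rss_mb` and `report_memory` are hypothetical helper names.

```python
import gc


def current_rss_mb():
    """Resident set size of this process in MB (Linux only, reads /proc).

    Returns None on platforms without /proc/self/status.
    """
    try:
        with open("/proc/self/status") as fh:
            for line in fh:
                if line.startswith("VmRSS:"):
                    # The line looks like: "VmRSS:    123456 kB"
                    return int(line.split()[1]) / 1024
    except OSError:
        return None
    return None


def report_memory(fn, *args, **kwargs):
    """Run fn(*args, **kwargs) and print RSS before, after, and after gc.collect()."""
    before = current_rss_mb()
    result = fn(*args, **kwargs)
    after = current_rss_mb()
    gc.collect()  # drop unreachable Python objects before measuring again
    after_gc = current_rss_mb()
    print(f"RSS (MB) before: {before}, after: {after}, after gc: {after_gc}")
    return result, (before, after, after_gc)
```

For example, calling `report_memory(mcmc.run, rng_key_, x_min=x_min, x_max=x_max, data=data)` twice in a row in the notebook, and once in a bash-launched script, would show whether the growth repeats.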
Here is the code for the model and the inference:
```python
import numpy as np
import jax.numpy as jnp
from jax import random as jrandom
import numpyro
import numpyro.distributions as dist

# f_jax and generate_data are defined in the next code block;
# run that block before this one.


def model(data=None, x_min=None, x_max=None):
    # Component means and scales, all constrained to be positive
    mu_0 = numpyro.sample(
        "mu_0", dist.TruncatedNormal(loc=1, scale=1, low=0))
    mu_1 = numpyro.sample(
        "mu_1", dist.TruncatedNormal(loc=5, scale=1, low=0))
    sigma_1 = numpyro.sample(
        "sigma_1", dist.TruncatedNormal(loc=3, scale=1, low=0))
    sigma_2 = numpyro.sample(
        "sigma_2", dist.TruncatedNormal(loc=5, scale=1, low=0))
    # Slope of the data-dependent mean
    a = numpyro.sample("a", dist.TruncatedNormal(loc=3, scale=1, low=0))
    # Random threshold, restricted to [18, 22]
    x_ref = numpyro.sample("x_ref", dist.TwoSidedTruncatedDistribution(
        dist.Normal(loc=20, scale=1), low=18, high=22))
    # Mixture weights
    p = numpyro.sample("p", dist.Dirichlet(jnp.ones(3)))
    # One value per data point, shape (N,)
    intermediate_quantity = f_jax(x_ref=x_ref, x_min=x_min, x_max=x_max)
    with numpyro.plate("data", data.size):
        numpyro.sample(
            "values",
            dist.MixtureSameFamily(
                dist.Categorical(probs=p),
                dist.TruncatedNormal(
                    # Workaround: tile the constant means to shape (N,)
                    # so that loc has shape (N, 3)
                    loc=jnp.stack(
                        [
                            mu_0 * jnp.ones(intermediate_quantity.shape),
                            mu_1 * jnp.ones(intermediate_quantity.shape),
                            mu_1 + a * intermediate_quantity
                        ]
                    ).T,
                    # Component 0 has a fixed scale of 1
                    scale=jnp.array(
                        [
                            1,
                            sigma_1,
                            sigma_2
                        ]
                    ),
                    low=0
                )
            ),
            obs=data
        )


rng_key = jrandom.PRNGKey(0)
rng_key, rng_key_ = jrandom.split(rng_key)
size = 1000
x_min = np.random.uniform(-5.0, 15.0, size=size)
x_max = x_min + np.abs(np.random.normal(10, 1, size=size))
data = generate_data(size, x_min, x_max)
kernel = numpyro.infer.NUTS(model)
# On CPU the 4 chains run sequentially unless numpyro.set_host_device_count(4)
# is called at the start of the program, before any JAX computation.
mcmc = numpyro.infer.MCMC(
    kernel, num_warmup=2500, num_samples=2500, num_chains=4)
mcmc.run(rng_key_, x_min=x_min, x_max=x_max, data=data)
mcmc.print_summary()
```
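For comparison with the `MixtureSameFamily` call, this is a sketch of the per-observation log-likelihood it should evaluate, written out with `jax.scipy` under the model's assumptions (every component is a normal truncated below at 0, and the mixture weights are shared by all points). `truncnorm_lower0_logpdf` and `mixture_loglik` are hypothetical names. Written this way, each data point just uses its own row of an (N, 3) array of means, so no per-point indexing inside the plate is required.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm


def truncnorm_lower0_logpdf(y, loc, scale):
    """Log-density of Normal(loc, scale) truncated to [0, inf), for y >= 0.

    The normalizing mass is P(Y >= 0) = 1 - Phi(-loc / scale) = Phi(loc / scale).
    """
    return norm.logpdf(y, loc=loc, scale=scale) - norm.logcdf(loc / scale)


def mixture_loglik(y, probs, locs, scales):
    """Per-observation log-likelihood of a truncated-normal mixture.

    y: (N,) observations, probs: (K,) weights, locs: (N, K), scales: (K,).
    """
    comp = truncnorm_lower0_logpdf(y[:, None], locs, scales)  # (N, K)
    return logsumexp(jnp.log(probs) + comp, axis=-1)          # (N,)
```

In principle the sum of its output could replace the `"values"` site via `numpyro.factor`, but that is untested here.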
The model and the data simulation rely on the following helper functions:
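```python
import numpy as np
import jax.numpy as jnp
from jax import jit
from scipy import stats


@jit
def f_jax(x_ref, x_min, x_max):
    """JAX version of f, used inside the model (x_ref is a sampled value)."""
    filter_sup = x_ref >= x_max
    filter_middle = (x_ref < x_max) & (x_ref >= x_min)
    # Above the interval: offset from its midpoint; below the interval: 0
    result = jnp.where(filter_sup, x_ref - (x_max + x_min) / 2, 0)
    # Inside the interval
    result = jnp.where(filter_middle, (x_ref - x_min) *
                       (1 + (x_ref - x_min) / (x_max - x_min)), result)
    return result


def f(x_ref, x_min, x_max):
    """NumPy version of f_jax, used to simulate the data."""
    filter_sup = x_ref >= x_max
    filter_middle = (x_ref < x_max) & (x_ref >= x_min)
    result = np.where(filter_sup, x_ref - (x_max + x_min) / 2, 0)
    result = np.where(filter_middle, (x_ref - x_min) *
                      (1 + (x_ref - x_min) / (x_max - x_min)), result)
    return result


def generate_data(size, x_min, x_max):
    """Simulate `size` observations, assigned to the 3 components in turn."""
    classes = np.arange(size) % 3
    # True parameter values used for the simulation
    x_ref = 20
    a = 3
    # Means of components 0 and 1, tiled to shape (2, size)
    mus = np.array([1, 5]).reshape(-1, 1) * np.ones((1, size))
    # Append the data-dependent mean of component 2 -> shape (3, size)
    mus = np.vstack(
        [mus, (mus[1, :] + a * f(x_ref, x_min, x_max)).reshape(1, -1)])
    sigmas = np.array([1, 3, 5]).reshape(-1, 1)
    # truncnorm bounds are in standard units: lower bound 0 -> -mu / sigma
    values = stats.truncnorm.rvs(-mus / sigmas,
                                 np.inf, loc=mus, scale=sigmas).T
    # For each observation, keep the draw from its assigned component
    return values[np.arange(len(classes)), classes]
```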
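For reference, `f_jax` (used in the model) and `f` (used to simulate the data) implement the same piecewise function, applied elementwise over the data points:

```latex
f(x_{\text{ref}}, x_{\min}, x_{\max}) =
\begin{cases}
x_{\text{ref}} - \dfrac{x_{\max} + x_{\min}}{2}, & x_{\text{ref}} \ge x_{\max},\\[6pt]
(x_{\text{ref}} - x_{\min})\left(1 + \dfrac{x_{\text{ref}} - x_{\min}}{x_{\max} - x_{\min}}\right), & x_{\min} \le x_{\text{ref}} < x_{\max},\\[6pt]
0, & x_{\text{ref}} < x_{\min}.
\end{cases}
```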