NumPyro memory usage at the beginning of NUTS inference

Hello. I am a former Stan/PyMC3 guy who recently joined the Pyro community. So far I am very happy with the transition, but I'm still getting there. One issue I am having a hard time with is memory usage at the beginning of NUTS.

I am running a simple linear regression with a large number of coefficients (time-domain convolutions in matrix form). Basically, I am reconstructing sound fields over a volume by projecting the measurements onto uniformly distributed sources, which are then used for interpolation/extrapolation. The model is:

```python
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist


def model_sbl_plate(data):
    obs = data["obs"]
    Nw = data["Nw"]  # number of sources
    Nt = data["Nt"]  # length of the time signals
    X = data["X"]    # dictionary matrix, shape (len(obs), Nw * Nt)
    b = data["b"]
    tau0 = data["tau0"]

    sigma = numpyro.sample("sigma", dist.LogNormal(0.0, 1.0))
    dist_var = dist.InverseGamma(1, 10 ** (-b) * jnp.ones(Nw * Nt))
    with numpyro.plate("coefficients", Nw * Nt):
        # w_var is the prior standard deviation: sqrt of an InverseGamma variance
        w_var = jnp.sqrt(numpyro.sample("w_var", dist_var))
        dist_w = dist.Normal(jnp.zeros((Nw * Nt)), tau0 * w_var)
        w = numpyro.sample("w", dist_w)
    mean = X @ w
    y = dist.Normal(mean, sigma)
    with numpyro.plate("observations", len(obs)):
        numpyro.sample("obs", y, obs=obs)
```

where Nw is the number of sources and Nt is the length of the time signals. The dataset has ~10400 observations and the dictionary X is roughly (10400, 60000), so the vector of unknowns w has 60000 entries.
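For scale, here is a back-of-the-envelope estimate of how much memory X alone needs at those dimensions, assuming dense storage. The post does not say whether X is float32 or float64, so both are shown (GB here means 10^9 bytes). It suggests that only a few temporary copies of X would already amount to many GB, though that is a hypothesis rather than a diagnosis.

```python
def dense_matrix_gb(rows: int, cols: int, bytes_per_element: int) -> float:
    """Size in GB (10**9 bytes) of a dense rows x cols array."""
    return rows * cols * bytes_per_element / 1e9


n_obs, n_coef = 10_400, 60_000
print(f"float32: {dense_matrix_gb(n_obs, n_coef, 4):.2f} GB")  # float32: 2.50 GB
print(f"float64: {dense_matrix_gb(n_obs, n_coef, 8):.2f} GB")  # float64: 4.99 GB
```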

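I run the inference like this:

```python
import jax
from numpyro.infer import MCMC, NUTS

my_kernel = NUTS(model_sbl_plate, max_tree_depth=7)
posterior = MCMC(my_kernel, num_samples=100, num_warmup=50, num_chains=1)
# Model arguments go to run(), not to the MCMC constructor; the seed is arbitrary.
posterior.run(jax.random.PRNGKey(0), data=data)
```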

At the beginning of inference, memory usage grows substantially. I am prototyping on my own machine with 32 GB of RAM, which is not enough, so the process starts swapping. Once the sampler reaches the second sample, memory consumption stabilizes at around 15 GB.
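One detail that may matter, assuming X is built in NumPy as float64 (the post doesn't say): JAX runs in 32-bit mode by default, so a float64 X is converted to a float32 copy when it reaches JAX, and if the NumPy original stays referenced, both copies live in RAM at the same time. Below is a minimal sketch of converting once and dropping the host original; `to_device_float32` is a hypothetical helper name, and this is not claimed to remove the start-up spike on its own.

```python
import numpy as np
import jax
import jax.numpy as jnp


def to_device_float32(x_host):
    """Hypothetical helper: cast a host array to float32 once and place it on
    JAX's default device, so later JAX calls don't need their own conversion."""
    return jax.device_put(np.asarray(x_host, dtype=np.float32))


# Usage sketch with a small stand-in for the real dictionary matrix:
X_np = np.random.default_rng(0).standard_normal((4, 3))  # float64 on the host
X = to_device_float32(X_np)
del X_np  # drop the float64 original so only one copy remains referenced
print(X.dtype, X.shape)  # float32 (4, 3)
```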

I don't know what counts as a large model by NumPyro standards, but I would like to know whether there is any way to alleviate this. I am trying to understand how to use plates, but I don't know whether they can help in this case.

For further info, sampling takes around 100 s/it, about 3 hours in total. The priors still need tuning, but I also wonder whether that is a reasonable computation time for NumPyro.
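A rough way to see where the per-iteration cost comes from. This is a sketch that counts only the two dense mat-vecs with X per gradient and ignores elementwise work, prior terms and memory bandwidth. With max_tree_depth=7, a NUTS iteration can take up to 2^7 - 1 = 127 leapfrog steps, and each step needs one gradient of the log density.

```python
def max_leapfrog_steps(max_tree_depth: int) -> int:
    # A NUTS tree of depth d holds at most 2**d states, i.e. at most 2**d - 1
    # leapfrog steps beyond the starting point; each step needs one gradient.
    return 2 ** max_tree_depth - 1


def matvec_flops_per_gradient(n_obs: int, n_coef: int) -> int:
    # Forward pass X @ w plus backward pass X.T @ g: two dense mat-vecs,
    # each about 2 * n_obs * n_coef floating-point operations.
    return 2 * (2 * n_obs * n_coef)


steps = max_leapfrog_steps(7)
per_grad = matvec_flops_per_gradient(10_400, 60_000)
print(steps)  # 127
print(f"{per_grad / 1e9:.2f} GFLOP per gradient")  # 2.50 GFLOP per gradient
print(f"{steps * per_grad / 1e9:.0f} GFLOP per iteration at most")  # 317 GFLOP per iteration at most
```

Whether iterations actually hit the tree-depth cap can be checked by passing `extra_fields=("num_steps",)` to `MCMC.run` and inspecting `get_extra_fields()["num_steps"]` afterwards.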

Any help is welcome.


Welcome, @dicano! I don't have experience with a model with this many parameters, but the high memory usage at the beginning is unusual to me. It is probably related to how JAX works, but I need to investigate more to be sure. Could you open a GitHub issue for this? I'll follow up.

Regarding the sampling time, I guess that is expected if you run this model on CPU. With this many parameters, would it be better to use SVI with, e.g., an AutoNormal guide?

Thanks a lot for the reply! I am trying to avoid SVI and focus on MCMC for research purposes (we have developed a nice research thread on Gaussian processes and I would like to stick to it for now).

In any case, I am debugging it to find the cause. So far I have found that the problem appears when XLA copies the dictionary matrix X, which is precomputed outside NumPyro. Once I have prepared a proper report, I'll open the GitHub issue so that you can narrow it down more easily.
