Hello. I am a former Stan/PyMC3 guy who recently joined the Pyro community. So far I am very happy with the transition, but I am still finding my way around. One issue I am struggling with is memory usage at the beginning of NUTS.
I am running a simple linear regression with a large number of coefficients (time-domain convolutions in matrix form). Basically, I am reconstructing sound fields in a volume by projecting the measurements onto uniformly distributed sources for subsequent interpolation/extrapolation. The model is:
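```python
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist


def model_sbl_plate(data):
    obs = data["obs"]
    Nw = data["Nw"]  # number of sources
    Nt = data["Nt"]  # length of the time signals
    X = data["X"]    # dictionary, shape (len(obs), Nw * Nt)
    b = data["b"]
    tau0 = data["tau0"]

    # observation noise scale
    sigma = numpyro.sample("sigma", dist.LogNormal(0.0, 1.0))

    # one InverseGamma variance prior per coefficient
    dist_var = dist.InverseGamma(1, 10 ** (-b) * jnp.ones(Nw * Nt))
    with numpyro.plate("coefficients", Nw * Nt):
        # sample the variance, then take sqrt to get a standard deviation
        w_var = jnp.sqrt(numpyro.sample("w_var", dist_var))
        dist_w = dist.Normal(jnp.zeros(Nw * Nt), tau0 * w_var)
        w = numpyro.sample("w", dist_w)

    # linear model: (len(obs), Nw * Nt) @ (Nw * Nt,) -> (len(obs),)
    mean = X @ w
    y = dist.Normal(mean, sigma)
    with numpyro.plate("observations", len(obs)):
        numpyro.sample("obs", y, obs=obs)
```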
where Nw is the number of sources and Nt is the length of the time signals. The dataset has ~10,400 observations and the dictionary X has shape ~(10400, 60000), so the vector of unknowns w has ~60,000 entries.
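For anyone trying to reproduce this: the data dict itself isn't shown above, so here is a hypothetical toy-sized version with the keys and shapes that model_sbl_plate reads. The sizes, the random values, and the values of b and tau0 are made up purely for illustration; only the shape relationship (X has one row per observation and one column per coefficient) comes from the model.

```python
import numpy as np

# Hypothetical toy sizes; the real problem has len(obs) ~ 10400 and Nw * Nt ~ 60000.
Nw, Nt, n_obs = 4, 8, 20
rng = np.random.default_rng(0)

data = {
    "Nw": Nw,                                    # number of sources
    "Nt": Nt,                                    # length of each time signal
    "X": rng.standard_normal((n_obs, Nw * Nt)),  # dictionary: one column per coefficient
    "obs": rng.standard_normal(n_obs),           # one value per observation
    "b": 1.0,                                    # hypothetical prior exponent
    "tau0": 1.0,                                 # hypothetical global scale
}

# Shape contract implied by mean = X @ w with w of length Nw * Nt
assert data["X"].shape == (len(data["obs"]), data["Nw"] * data["Nt"])
```

A small dict like this makes it quick to check that the model runs before scaling up to the full-size X.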
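I am running inference like this:
```python
import jax
from numpyro.infer import MCMC, NUTS
rng_key_ = jax.random.PRNGKey(0)  # seed chosen arbitrarily
my_kernel = NUTS(model_sbl_plate, max_tree_depth=7)
posterior = MCMC(my_kernel, num_samples=100, num_warmup=50, num_chains=1)
posterior.run(rng_key_, data=data)
```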
What I see is that memory usage grows a lot at the beginning of inference. I am prototyping on my own machine with 32 GB of RAM, which is not sufficient, so the process goes into swap. Once the sampler reaches the second sample, memory consumption stabilizes at around 15 GB.
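For scale, here is a rough estimate of how much memory the dense X alone needs, using the sizes quoted above. It deliberately ignores everything else that NUTS and XLA allocate (gradient buffers, intermediate results, any copies made during compilation), so treat it as a lower bound on the footprint of X, not a prediction of total usage.

```python
# Back-of-envelope size of the dense dictionary X alone.
# Ignores JAX/XLA overhead, gradient buffers, and any extra copies of X.
n_obs = 10_400     # observations (rows of X), from the post
n_coef = 60_000    # coefficients Nw * Nt (columns of X), from the post

n_elements = n_obs * n_coef
gib = 1024 ** 3

size_float32_gib = n_elements * 4 / gib  # 4 bytes per float32 (JAX default)
size_float64_gib = n_elements * 8 / gib  # 8 bytes per float64 (only if 64-bit mode is enabled)

print(f"X has {n_elements:,} entries")
print(f"float32: {size_float32_gib:.2f} GiB, float64: {size_float64_gib:.2f} GiB")
```

By this estimate a single copy of X is roughly 2.3 GiB in float32 (about 4.6 GiB in float64), well below the ~15 GB steady state, so other allocations or extra copies of X presumably account for most of the usage; this calculation alone can't say which.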
I don't know what counts as a large model by NumPyro standards, but I would like to know whether there is any way to alleviate this. I am trying to understand how to use plates, but I don't know whether they can help in this case.
For reference, sampling takes around 100 s/it, about 3 hours in total. The priors still need tuning, but I also wonder whether that is a reasonable computation time for NumPyro.
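To frame the timing question, here is a back-of-envelope calculation using only the settings and timings in this post. Counting only the dense matrix-vector products in the gradient is a simplification of mine: it ignores the prior terms and all other overhead, and it assumes the NUTS tree keeps doubling up to max_tree_depth, which gives an upper bound on leapfrog steps rather than the number actually taken.

```python
# Rough per-iteration and per-gradient cost, from the settings and timings in the post.
num_warmup = 50
num_samples = 100
total_hours = 3.0      # reported total wall time
max_tree_depth = 7

iterations = num_warmup + num_samples
avg_seconds_per_it = total_hours * 3600 / iterations

# Upper bound on leapfrog steps (one gradient evaluation each) per NUTS iteration,
# assuming the trajectory doubles until max_tree_depth is reached.
max_leapfrog_steps = 2 ** max_tree_depth - 1

# The gradient needs X @ w and a product with X.T; each dense matvec
# costs about 2 * n_obs * n_coef floating-point operations.
n_obs, n_coef = 10_400, 60_000
flops_per_gradient = 2 * (2 * n_obs * n_coef)

print(f"{iterations} iterations, ~{avg_seconds_per_it:.0f} s/it on average")
print(f"at most {max_leapfrog_steps} gradients per iteration")
print(f"~{flops_per_gradient / 1e9:.1f} GFLOP per gradient (matvecs only)")
```

Averaged over all 150 iterations, 3 hours works out to about 72 s/it, in the same ballpark as the quoted ~100 s/it. Whether that is reasonable depends mostly on how many leapfrog steps each iteration actually takes, since each step costs at least two passes over X.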
Any help is welcome.
Thanks.