Large Memory Requirements

Hi all,

Thanks again for the awesome software/api–it has made prototyping of ideas a breeze.

I’m investigating a simple model with a fairly shallow network and running into memory issues. My data are shaped roughly (700000, 200) and I’m trying to make predictions using a network of size (100, 25)x(25,1), which doesn’t seem so large imo. When using either haiku or stax to implement the network, I run into memory issues during inference, to the point that my task is killed by our HPC management software (using up to 32 GB!). I’m only performing MLE inference and not placing priors on the weights in the network. Here is my current model code:

import haiku as hk
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from jax import nn
from numpyro.contrib.module import haiku_module

def haiku_encoder(h_dim1, h_dim2):
    # Two hidden layers followed by a single-unit output layer
    return hk.transform(lambda a: hk.Sequential([
        hk.Linear(h_dim1), nn.relu,
        hk.Linear(h_dim2), nn.relu,
        hk.Linear(1), nn.relu,
    ])(a))

def model_dnn(y: jnp.ndarray, X: jnp.ndarray = None, h_dim1 = 100, h_dim2 = 25) -> None:

    batch_dim, annot_dim = X.shape
    f_n = haiku_encoder(h_dim1, h_dim2)
    compute_s = haiku_module('s_j', f_n, input_shape=(batch_dim, annot_dim))
    s_j = compute_s(X)
    numpyro.sample('y', dist.Gamma(1/2., 1/(2. * s_j)), obs=y)


Inference seems to bork out when needing to compute log_prob inside the event handler for the model. I’ve looked at hand coding the network and was able to run inference, but ran into other issues I won’t detail here.

Any thoughts on why computing log_prob for a small dataset requires so many GB of memory?


@nmancuso It might be that some log_prob computation is not optimized. Could you add some print statements in the model to get the shapes of s_j and y? If y shape is (70000,) and s_j shape is (70000, 1), then log_prob will be (70000, 70000), which is pretty large IMO.
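To make the shape argument concrete, here is a minimal sketch of the broadcasting involved. It uses plain NumPy as a stand-in for jax.numpy (an assumption made for portability; the broadcasting rules are the same), and `gamma_log_prob` is a hypothetical hand-written helper, not NumPyro's implementation. The data are random and only the shapes matter.

```python
import math
import numpy as np

def gamma_log_prob(value, concentration, rate):
    """Elementwise Gamma log-density; arguments broadcast against each other."""
    return (concentration * np.log(rate)
            + (concentration - 1.0) * np.log(value)
            - rate * value
            - math.lgamma(concentration))

n = 5  # small stand-in for the 70000 observations
rng = np.random.default_rng(0)
y = rng.uniform(0.1, 3.0, size=n)          # shape (n,)
s_j = rng.uniform(0.5, 2.0, size=(n, 1))   # shape (n, 1), like the output of hk.Linear(1)

# (n,) against (n, 1) broadcasts to (n, n): every y paired with every s_j
lp_bad = gamma_log_prob(y, 0.5, 1.0 / (2.0 * s_j))
# Dropping the trailing axis pairs each y only with its own s_j
lp_good = gamma_log_prob(y, 0.5, 1.0 / (2.0 * s_j[..., 0]))

print(lp_bad.shape, lp_good.shape)

# Rough size of the accidental (70000, 70000) array, assuming float32
full_n = 70_000
bytes_needed = full_n * full_n * 4
print(f"{bytes_needed / 1e9:.1f} GB")
```

The intended per-observation values sit on the diagonal of the large array; everything off the diagonal is wasted work and memory, and gradients through it during inference need further copies of similar size.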

Btw, not related but svi.init might be a bit faster if you use input_shape=(annot_dim,) in the constructor of haiku_module.

Thanks @fehiepsi. After editing initialization to input_shape=(annot_dim,), you’re correct that the shapes are y.shape == (70000, ) and s_j.shape == (70000,1).

I’m not familiar with the internal workings of log_prob, but shouldn’t this require shape (70000, 1) or (70000,)? All observations are independent given their respective s_j entries.

Oh, that is what I suspected. We should have s_j.shape == (70000,). You can resolve the issue by using s_j = compute_s(X)[..., 0]. :)

hk.Linear(1) gives an output whose last dimension has size 1. You can also remove that dimension inside the network by defining it as: lambda a: hk.Sequential([...])(a)[..., 0]
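For anyone hand-coding the network instead, the same idea can be sketched as follows. This is a NumPy stand-in for the haiku network, not the haiku API itself: `init_params` and `forward` are hypothetical names, the weights are random, and the layer sizes simply mirror the ones in the question. The shape assertion before the likelihood is an optional guard that would have caught the problem early.

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

def init_params(rng, in_dim, h_dim1=100, h_dim2=25):
    # One (weights, bias) pair per layer: in_dim -> h_dim1 -> h_dim2 -> 1
    sizes = [(in_dim, h_dim1), (h_dim1, h_dim2), (h_dim2, 1)]
    return [(rng.normal(scale=0.1, size=s), np.zeros(s[1])) for s in sizes]

def forward(params, X):
    h = X
    for W, b in params:
        h = relu(h @ W + b)
    # The last layer has one unit, so h is (batch, 1); drop the trailing axis
    return h[..., 0]

rng = np.random.default_rng(0)
X = rng.normal(size=(8, 200))        # small stand-in for the real design matrix
y = rng.uniform(0.1, 3.0, size=8)

params = init_params(rng, X.shape[1])
s_j = forward(params, X)

# Guard: the network output and the observations must have identical shapes
assert s_j.shape == y.shape, (s_j.shape, y.shape)
print(s_j.shape)
```

Indexing with `[..., 0]` rather than calling a blanket squeeze only ever removes the last axis, so a batch of size 1 keeps its batch dimension.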

Great, thanks for the speedy reply. That seems to have done the trick.