Hi all,

Thanks again for the awesome software/API – it has made prototyping ideas a breeze.

I’m investigating a simple model with a fairly shallow network and running into memory issues. My data are shaped roughly `(700000, 200)`, and I’m trying to make predictions using a network of size `(100, 25) x (25, 1)`, which doesn’t seem so large imo. When using either haiku or stax to implement the network, I run into memory issues during inference such that my task is killed by our HPC management software (it uses up to 32 GB!). I’m only performing MLE inference and not placing priors on the weights in the network. Here is my current model code:

```
import haiku as hk
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.module import haiku_module

def haiku_encoder(h_dim1, h_dim2):
    return hk.transform(lambda a: hk.Sequential([
        hk.Linear(h_dim1), jax.nn.relu,
        hk.Linear(h_dim2), jax.nn.relu,
        hk.Linear(1), jax.nn.relu,
    ])(a))

def model_dnn(y: jnp.ndarray, X: jnp.ndarray = None, h_dim1=100, h_dim2=25) -> None:
    batch_dim, annot_dim = X.shape
    f_n = haiku_encoder(h_dim1, h_dim2)
    compute_s = haiku_module('s_j', f_n, input_shape=(batch_dim, annot_dim))
    s_j = compute_s(X)
    numpyro.sample('y', dist.Gamma(1 / 2., 1 / (2. * s_j)), obs=y)
```
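For what it’s worth, a rough float32 tally of the arrays involved (shapes taken from above; this is just back-of-the-envelope arithmetic, not a claim about what JAX actually allocates) comes nowhere near 32 GB:

```
# Back-of-the-envelope float32 sizes for the arrays in the model above.
n_obs, n_feat = 700_000, 200   # data shape from the post
h1, h2 = 100, 25               # hidden layer widths
bytes_f32 = 4                  # bytes per float32 element

x = n_obs * n_feat * bytes_f32   # input matrix X
a1 = n_obs * h1 * bytes_f32      # first hidden activation
a2 = n_obs * h2 * bytes_f32      # second hidden activation
out = n_obs * 1 * bytes_f32      # s_j, and the Gamma log_prob is the same shape

total_gb = (x + a1 + a2 + out) / 1e9
print(f"{total_gb:.2f} GB")  # prints 0.91 GB
```

So even keeping every intermediate activation alive at once should only need on the order of 1 GB, which is why the 32 GB figure surprises me.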

Inference seems to bork out when it needs to compute `log_prob` inside the event handler for the model. I’ve looked at hand-coding the network and was able to run inference, but ran into other issues I won’t detail here.

Any thoughts on why computing `log_prob` for a small dataset requires so many GB of memory?

Thanks.