I wrote a `log_prob` function that consumes substantial memory on each iteration it carries out, eventually running out of memory. The function is roughly as follows (note that `probs` is a 2D array):
```python
def log_prob(self, probs):
    logp = 0.
    for i in range(len(probs)):
        p = probs[i]
        # builds a 1D array based on a transformed p
        initialValPp = ...
        out = lax.fori_loop(1, self.max_len_p, self.f_partial, initialValPp)
        pval = out[self.max_len_p + self.counts[i]]
        ln_pval = jnp.log(pval)
        logp += ln_pval
    return logp
```
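For concreteness, here is a self-contained mock-up of the same pattern. The `f_partial` body, the `cumsum` transform, the small shapes, and the simplified index into `out` are all made up for illustration; only the loop structure matches my real code:

```python
import jax.numpy as jnp
from jax import lax

max_len_p = 10  # made-up size for illustration


def f_partial(i, val):
    # dummy stand-in for the real per-step update
    return val * 0.99 + i * 1e-3


def log_prob(probs, counts):
    logp = 0.
    for i in range(len(probs)):
        p = probs[i]
        initialValPp = jnp.cumsum(p)  # stand-in for the real transform of p
        out = lax.fori_loop(1, max_len_p, f_partial, initialValPp)
        pval = out[counts[i]]  # index simplified relative to the real code
        logp += jnp.log(pval)
    return logp


probs = jnp.full((3, max_len_p), 0.5)
counts = jnp.array([0, 1, 2])
result = log_prob(probs, counts)
print(result)
```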
Each time the fori_loop executes, available memory drops by about 220 MB. The initialValPp array is only about 88 KB, but the fori_loop executes thousands of times, so I'm guessing each execution somehow adds to the total memory consumed.
I only need one value from `out`, so I tried extracting a float from `out`. That solves the memory problem during the initial MCMC phase that doesn't jit the function (evidently while it searches for viable initial values). However, once MCMC jits the function (right after the progress bar appears), extracting a float from the traced array raises a ConcretizationTypeError that apparently can't be avoided with this approach.
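A minimal reproduction of that failure mode (with toy data in place of my real `out`) looks like this:

```python
import jax
import jax.numpy as jnp
from jax.errors import ConcretizationTypeError


def pick(out, idx):
    # float() forces a concrete value: fine eagerly, but fails on a tracer
    return float(out[idx])


out = jnp.arange(5.0)
eager_val = pick(out, 2)  # works outside jit
print(eager_val)

caught = None
try:
    jax.jit(pick, static_argnums=1)(out, 2)
except ConcretizationTypeError as e:
    caught = type(e).__name__
print("under jit:", caught)
```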
I saw a discussion of this type of error on a JAX forum that suggested jitting the function as a solution (memory management is supposedly better under jit). I'm not sure NumPyro distribution classes can have a jitted `log_prob`; when I tried, I got an odd error claiming I had not passed `mcmc` two arguments that I clearly did.
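For reference, the pattern I know of for jitting a method marks `self` as a static argument (this assumes `self` only carries hashable attributes; the toy `log_prob` body here is made up):

```python
from functools import partial

import jax
import jax.numpy as jnp


class MyDist:
    def __init__(self, max_len_p):
        self.max_len_p = max_len_p  # plain hashable attribute

    @partial(jax.jit, static_argnums=0)  # treat self as static
    def log_prob(self, value):
        # toy body standing in for the real computation
        return jnp.sum(-0.5 * value ** 2)


d = MyDist(10)
lp = d.log_prob(jnp.ones(3))
print(lp)  # -1.5
```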
Finally, I thought that converting the outer for loop to a fori_loop might help. I tried a mock-up in which the outer loop simply kept feeding the same data to the inner loop, but this still runs out of memory.
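My mock-up looked roughly like the following, except that here I've sketched the outer loop with `lax.scan` rather than `fori_loop`, since scan carries `logp` and slices `probs` per step naturally (`f_partial`, the `cumsum` transform, and the shapes are again stand-ins):

```python
import jax.numpy as jnp
from jax import lax

max_len_p = 10  # made-up size for illustration


def f_partial(i, val):
    # dummy stand-in for the real per-step update
    return val * 0.99


def log_prob(probs, counts):
    def body(logp, xs):
        p, c = xs
        init = jnp.cumsum(p)  # stand-in for the real transform of p
        out = lax.fori_loop(1, max_len_p, f_partial, init)
        return logp + jnp.log(out[c]), None

    logp, _ = lax.scan(body, jnp.zeros(()), (probs, counts))
    return logp


probs = jnp.full((4, max_len_p), 0.5)
counts = jnp.array([1, 2, 3, 0])
result = log_prob(probs, counts)
print(result)
```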
Any suggestions as to what I can try are welcome!