Out of memory problem with fori_loop

I wrote a log_prob function that consumes substantial memory on each iteration it carries out, eventually running out of memory. The function looks roughly like this (note that probs is a 2D array):

import jax.numpy as jnp
from jax import lax

def log_prob(self, probs):
    logp = 0.
    for i in range(len(probs)):
        p = probs[i]
        # builds a 1D array from a transformed p (details omitted)
        initialValPp = ...
        out = lax.fori_loop(1, self.max_len_p, self.f_partial, initialValPp)
        pval = out[self.max_len_p + self.counts[i]]
        ln_pval = jnp.log(pval)
        logp += ln_pval
    return logp

Each time the fori_loop executes, available memory drops by about 220MB. The initialValPp array is only about 88KB, but the fori_loop executes thousands of times. I'm guessing that each execution somehow adds to the memory being consumed.
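(For reference, jax's device-memory profiler can show what is actually holding the memory; it dumps a pprof-format snapshot of live device allocations:)

import jax
import jax.numpy as jnp

x = jnp.ones((1000, 1000))  # allocate something on the device
# writes a pprof-format snapshot of current device allocations
jax.profiler.save_device_memory_profile("memory.prof")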

I only need one value from out, so I tried extracting a Python float from it. That solves the memory problem during the initial mcmc execution that doesn't jit the function (evidently while it's searching for viable initial values). However, when mcmc jits the function (right after the progress bar appears), extracting a float from the traced array raises a ConcretizationTypeError, which apparently can't be avoided with this approach.
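(A minimal sketch of the failure mode, with a made-up loop body and a hypothetical one_term function: float() demands a concrete value and so breaks under tracing, whereas indexing keeps the result as a traced scalar that jit handles fine:)

import jax
import jax.numpy as jnp
from jax import lax

def body(i, v):
    # stand-in for the real loop body
    return v * 1.001

@jax.jit
def one_term(init, idx):
    out = lax.fori_loop(1, 100, body, init)
    return jnp.log(out[idx])           # traced scalar: fine under jit
    # return jnp.log(float(out[idx]))  # raises ConcretizationTypeError under jit

print(one_term(jnp.ones(50), 3))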

I saw some discussion of this type of error on a jax forum that suggested jitting the function as a solution (memory management is supposed to be better under jit). I'm not sure NumPyro distribution classes can have a jitted log_prob method, though. When I tried, I got an odd error claiming that I did not give mcmc two arguments that I clearly did.

Finally, I thought that converting the outer for loop to a fori_loop might help. I tried a mock-up of that in which the outer loop simply kept sending the same inputs to the inner loop. That still ran out of memory.

Any suggestions as to what I can try are welcome!

I think you can try to simplify the problem first: write a log_prob function with minimal input and minimal content, then try to jit it to see if it works. For how to implement a jax-friendly function, I would recommend reaching out to the jax devs for help. You might also want to replace your for-loops with jax.lax.scan or jax.vmap for faster and less memory-hungry code.
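For example, something along these lines (a sketch only, with hypothetical init_fn and f_partial standing in for the real code that builds initialValPp and updates it):

import jax
import jax.numpy as jnp
from jax import lax

max_len_p = 50

def f_partial(i, v):
    # stand-in for the real per-iteration update
    return v * 0.999

def init_fn(p):
    # hypothetical: builds the 1D initial array from a transformed p
    return jnp.concatenate([p, jnp.ones(max_len_p)])

def one_row(p, count):
    out = lax.fori_loop(1, max_len_p, f_partial, init_fn(p))
    return jnp.log(out[max_len_p + count])

@jax.jit
def log_prob(probs, counts):
    # vmap evaluates every row inside a single traced computation,
    # instead of unrolling a Python for loop
    return jnp.sum(jax.vmap(one_row)(probs, counts))

probs = jnp.full((1000, 50), 0.5)
counts = jnp.zeros(1000, dtype=jnp.int32)
print(log_prob(probs, counts))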

Thank you fehiepsi.

I took my code out of NumPyro, and the memory problem disappears. I also went to some lengths to reproduce the problem's inputs, using numpyro.sample to generate initial parameter estimates. Still no problem.

I'm guessing that something NumPyro does behind the scenes results in the big memory usage. Maybe it's the auto-differentiation: if NumPyro is keeping track of intermediate values for backward (reverse-mode) differentiation, that could explain the big memory use. I'm not sure whether NUTS / HMC depends on backward differentiation, but if not, maybe turning it off would help?

Hmm, not sure if it helps, but you can set forward_mode_differentiation=True
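That flag is passed when the NUTS (or HMC) kernel is constructed; with a placeholder model it looks like this:

from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

def model():
    # placeholder model; the real one would use the custom distribution
    numpyro.sample("x", dist.Normal(0., 1.))

kernel = NUTS(model, forward_mode_differentiation=True)
mcmc = MCMC(kernel, num_warmup=500, num_samples=1000)
mcmc.run(random.PRNGKey(0))
mcmc.print_summary()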

Yup, that fixed the memory issue. Thank you so much!