I’d like to implement memoization in a way that’s compatible with NumPyro.
Suppose my model looks like this:
```python
import numpyro
import numpyro.distributions as dist

@mem
def bar():
    return numpyro.sample('a', dist.Beta(1, 1))

def model():
    x = bar()
    y = bar()
```
I’d like to implement `mem` so that:
- Within a single sample from this model, `x` and `y` always have the same value
- No memoization state is shared across samples
Here’s a gist that illustrates this in a bit more detail.
Since `mem` is stateful, I imagine that we’ll need to do something special to make it compatible with JAX.
Any advice on what to do?