I’d like to implement memoization in a way that’s compatible with NumPyro.
Suppose my model looks like this:
import numpyro
import numpyro.distributions as dist

@mem
def bar():
    return numpyro.sample('a', dist.Beta(1, 1))

def model():
    x = bar()
    y = bar()
I’d like to implement mem so that:
- Within a single sample from this model, x and y always have the same value
- No memoization state is shared across samples
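The two requirements above can be sketched in plain Python, leaving JAX aside for the moment. This is only a minimal illustration under stated assumptions, not the code from the gist: the cache is a module-level dict, the names clear_mem and run_model are borrowed from the reply further down, and random.betavariate stands in for numpyro.sample so the sketch runs without NumPyro.

```python
import functools
import random

# Module-level cache shared by all memoized functions. It only holds values
# drawn during the current model run (an assumption of this sketch).
_MEM_CACHE = {}

def clear_mem():
    """Forget all memoized values; call once before each model run."""
    _MEM_CACHE.clear()

def mem(fn):
    """Memoize fn on its hashable positional arguments until clear_mem() runs."""
    @functools.wraps(fn)
    def wrapper(*args):
        key = (fn.__qualname__, args)
        if key not in _MEM_CACHE:
            _MEM_CACHE[key] = fn(*args)
        return _MEM_CACHE[key]
    return wrapper

@mem
def bar():
    # Stand-in for numpyro.sample('a', dist.Beta(1, 1)); hypothetical, chosen
    # so the sketch needs only the standard library.
    return random.betavariate(1, 1)

def model():
    x = bar()
    y = bar()
    return x, y

def run_model():
    clear_mem()  # no memoization state carries over between samples
    return model()
```

Within one call to run_model, both calls to bar return the cached value, and clearing the cache at the start of each run keeps samples independent of one another.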
Here’s a gist that illustrates this in a bit more detail.
Since mem is stateful, I imagine that we’ll need to do something special to make it compatible with JAX.
Any advice on what to do?
@stuhlmueller If you use jit, you need to make sure that run_model reruns the model with a different random seed each time (in JAX, jit(f)() always returns the same value). Otherwise, every run of the model will use the same constant random state. For example,
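from pprint import pprint
import jax
import numpy as onp
import numpyro
import numpyro.distributions as dist

jitted_model = jax.jit(lambda rng: numpyro.handlers.seed(model, rng)())

def run_jitted(model, rng):
    clear_mem()  # reset the memoization cache (defined in the gist) before each run
    return onp.array(model(rng))

with numpyro.handlers.seed(rng_seed=0):
    for _ in range(10):
        # draw a fresh PRNG key so each run uses a different random state
        rng_key = numpyro.sample("rng_key", dist.PRNGIdentity())
        pprint(run_jitted(jitted_model, rng_key))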
might work as you expect.
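The point about jit and random state can be checked in isolation with plain JAX. The following is a small sketch, not part of the original reply; draw and jitted_draw are hypothetical names used only for illustration. A jitted function is deterministic in its inputs, so it only produces a new random value when it is given a new PRNG key.

```python
import jax

def draw(key):
    # All randomness comes from the explicit key argument.
    return jax.random.uniform(key)

jitted_draw = jax.jit(draw)

# Split one seed into two distinct keys.
key_a, key_b = jax.random.split(jax.random.PRNGKey(0))

same_1 = float(jitted_draw(key_a))
same_2 = float(jitted_draw(key_a))  # same key, so the same value again
other = float(jitted_draw(key_b))   # a different key gives a new draw
```

This is why the jitted model above takes the key as an argument rather than capturing a fixed seed inside the jitted function.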