Memoization in NumPyro

I’d like to implement memoization in a way that’s compatible with NumPyro.

Suppose my model looks like this:

import numpyro
import numpyro.distributions as dist

def bar():
    return numpyro.sample('a', dist.Beta(1, 1))

def model():
    x = bar()
    y = bar()
    return x, y

I’d like to implement a memoization wrapper mem, used as bar = mem(bar), so that:

  1. Within a single sample from this model, x and y always have the same value
  2. No memoization state is shared across samples

Here’s a gist that illustrates this in a bit more detail.

Since mem is stateful (it has to keep a cache of earlier results), I imagine we’ll need to do something special to make it compatible with JAX.

Any advice on what to do?

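@stuhlmueller If you use jit, you need to make sure that run_model reruns the model with a different random key each time. Random number generation in JAX is explicit: jit(f)() with no key argument bakes whatever key f uses into the compiled function, so every call returns the same value, i.e. the model runs with a constant random state. Passing the key in as an argument avoids this. For example,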

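import jax
import numpy as onp
from pprint import pprint
import numpyro

# The RNG key is an argument of the jitted function, so each call can use a fresh key.
jitted_model = jax.jit(lambda rng: numpyro.handlers.seed(model, rng)())

def run_jitted(model, rng):
    # Convert the JAX output to a regular NumPy array.
    return onp.array(model(rng))

with numpyro.handlers.seed(rng_seed=0):
    for _ in range(10):
        # Fresh key from the enclosing seed handler (older NumPyro: dist.PRNGIdentity()).
        rng_key = numpyro.prng_key()
        pprint(run_jitted(jitted_model, rng_key))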

should behave as you expect.
