Initialization and PRNGKey

Hi. I think I’m confusing some things in the initialization method of this GMM implementation. I want to make sure I understand where the randomness comes from in handlers.seed and in the other methods that take a PRNGKey as an input argument.

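import jax.numpy as jnp
from jax import random
from numpyro import handlers
from numpyro.infer.autoguide import AutoDelta
from numpyro.infer.initialization import init_to_value

# model, data, K and elbo are defined earlier in the GMM notebook.

def initialize(seed):
    global global_guide
    # Only "locs" is random: K data points picked uniformly using the seeded key.
    init_values = {
        "weights": jnp.ones(K) / K,
        "scale": jnp.sqrt(data.var() / 2),
        "locs": data[
            random.categorical(
                random.PRNGKey(seed), jnp.ones(len(data)) / len(data), shape=(K,)
            )
        ],
    }
    # Expose only the global sites to the guide.
    global_model = handlers.block(
        handlers.seed(model, random.PRNGKey(seed)),
        hide_fn=lambda site: site["name"]
        not in ["weights", "scale", "locs", "components"],
    )
    global_guide = AutoDelta(
        global_model, init_loc_fn=init_to_value(values=init_values)
    )
    handlers.seed(global_guide, random.PRNGKey(seed))(data)  # warm up the guide
    return elbo.loss(random.PRNGKey(989983), {}, model, global_guide, data)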

global_svi_result = global_svi.run(
    random.PRNGKey(0), 200 if not smoke_test else 2, data
)

My question is: apart from the key used to build the init_values dictionary, what is the purpose of the other random.PRNGKey(0) keys passed to the other methods? Changing them has no effect on the loss. So if my init_values dictionary is not random, this kind of search for the best seed makes no sense, am I right?
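To see where the seed can matter, here is a minimal JAX sketch of the "locs" entry of init_values. It assumes the garbled "locs" line draws K indices with jax.random.categorical (its arguments key, jnp.ones(len(data)) / len(data), shape=(K,) match that function's signature); pick_init_locs and the toy data are hypothetical. In JAX the key is the only source of randomness, so the same seed always yields the same starting locs, and the other entries of init_values do not use a key at all.

```python
import jax.numpy as jnp
from jax import random


def pick_init_locs(seed, data, K):
    """Hypothetical helper mirroring the "locs" entry of init_values:
    draw K indices uniformly at random (with replacement) and return those data points."""
    # random.categorical treats its second argument as logits; a constant
    # vector (as in the notebook) makes every data point equally likely.
    logits = jnp.ones(len(data)) / len(data)
    idx = random.categorical(random.PRNGKey(seed), logits, shape=(K,))
    return data[idx]


data = jnp.array([-1.0, 0.0, 1.0, 5.0, 6.0])
print(pick_init_locs(0, data, K=2))  # identical on every run for seed 0
print(pick_init_locs(1, data, K=2))  # another seed may pick other points
```

If "locs" were replaced by fixed values, nothing in init_values would depend on the seed any more.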

Thanks in advance.

Hi @alregamo

The warm-up line, handlers.seed(global_guide, random.PRNGKey(seed))(data), was added to run the guide forward once; it has no effect on the loss. It was needed to prevent the following JAX error (I don’t know JAX well enough to explain the error, but I remember that this line resolved it):

UnexpectedTracerError                     Traceback (most recent call last)
<ipython-input-7-b92f2c8281ff> in <cell line: 28>()
     27 # Choose the best among 100 random initializations.
---> 28 loss, seed = min((initialize(seed), seed) for seed in range(100))
     29 initialize(seed)  # initialize the global_guide
     30 print(f"seed = {seed}, initial_loss = {loss}")

22 frames
    [... skipping hidden 8 frame]

    [... skipping hidden 6 frame]

/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/ in _assert_live(self)
   1577   def _assert_live(self) -> None:
   1578     if not self._trace.main.jaxpr_stack:  # type: ignore
-> 1579       raise core.escaped_tracer_error(self, None)
   1581   def get_referent(self):

UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type float32[2] wrapped in a DynamicJaxprTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was get_importance_log_probs at /usr/local/lib/python3.10/dist-packages/numpyro/infer/ traced for eval_shape.
The leaked intermediate value was created on line /usr/local/lib/python3.10/dist-packages/numpyro/distributions/ (__call__). 
When the value was created, the final 5 stack frames (most recent last) excluding JAX-internal frames were:
/usr/local/lib/python3.10/dist-packages/numpyro/infer/ (_setup_prototype)
/usr/local/lib/python3.10/dist-packages/numpyro/ (__call__)
/usr/local/lib/python3.10/dist-packages/numpyro/infer/ (transform_fn)
/usr/local/lib/python3.10/dist-packages/numpyro/infer/ (<dictcomp>)
/usr/local/lib/python3.10/dist-packages/numpyro/distributions/ (__call__)
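A minimal sketch of what I believe is happening, assuming AutoDelta creates its parameters lazily on its first call (the _setup_prototype frame in the traceback points that way). LazyGuide is a hypothetical toy class, not NumPyro’s implementation, and jax.eval_shape stands in for the trace of get_importance_log_probs named in the error. If the first call happens inside a JAX trace, the state saved on the object is a tracer that escapes the trace; one ordinary call beforehand creates concrete values instead.

```python
import jax
import jax.numpy as jnp


class LazyGuide:
    """Hypothetical toy guide (not NumPyro's code) that creates its
    parameters the first time it is called."""

    def __init__(self):
        self.params = None

    def __call__(self, data):
        if self.params is None:
            # Side effect: state is saved on the object during the first call.
            self.params = jnp.mean(data)
        return self.params + data


data = jnp.arange(3.0)

# Cold start: the first call happens while JAX traces the function, so
# self.params is set to a tracer that outlives the trace (a leaked tracer).
cold = LazyGuide()
jax.eval_shape(cold, data)

# Warm start: one ordinary call creates concrete parameters first, so the
# later trace only reads them and stores nothing new.
warm = LazyGuide()
warm(data)
print(jax.eval_shape(warm, data).shape)  # (3,)
print(float(warm.params))                # 1.0
```

Calling float(cold.params) raises an error because the stored value is a tracer rather than an array; in the notebook, the analogous leaked value presumably surfaces as the UnexpectedTracerError above.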

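The PRNGKey here has no effect because every distribution in global_guide is a Delta distribution (that is what AutoDelta uses), and a Delta simply returns its value, here the provided init_values, without drawing on the PRNGKey. So if init_values itself does not depend on the seed, you should get the same initial loss for every seed, and searching for the best seed gains nothing.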
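The seed search from the traceback, min((initialize(seed), seed) for seed in range(100)), can be reproduced in plain Python to show what it returns in both cases. best_seed, deterministic_loss and seeded_loss are hypothetical stand-ins for the notebook’s loop and for initialize. Because min compares (loss, seed) tuples element by element, a tie in loss goes to the smaller seed, so with a seed-independent initialization the search just returns seed 0.

```python
import random


def best_seed(initial_loss, num_seeds=100):
    """Same pattern as the notebook: keep the (loss, seed) pair with the lowest loss.
    Ties in loss are broken by the smaller seed, because tuples compare element-wise."""
    return min((initial_loss(seed), seed) for seed in range(num_seeds))


def deterministic_loss(seed):
    # Hypothetical: nothing that affects the loss depends on the seed.
    return 123.4


def seeded_loss(seed):
    # Hypothetical: the seed changes the starting point, so the loss varies.
    return 100.0 + random.Random(seed).random()


print(best_seed(deterministic_loss))  # (123.4, 0): every seed ties, so seed 0 wins
print(best_seed(seeded_loss))         # the seed whose starting point has the lowest loss
```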