The problem I am trying to solve is the following:
I have a large number of datasets I would like to run SVI against, and for each dataset I want to run multiple SVI instances with different random seeds. The model and the guide are the same for every dataset, and I would like to reuse them to avoid recompiling.
To run one dataset across multiple random seeds (with a nice progress bar), I created the following SVI subclass.
import jax_tqdm
import jax
import numpyro
from functools import partial


class SVI_vec(numpyro.infer.SVI):
    def run(
        self,
        rng_key,
        num_chains,
        num_steps,
        *args,
        stable_update=False,
        forward_mode_differentiation=False,
        init_states=None,
        init_params=None,
        **kwargs
    ):
        # One SVI step, wrapped so that the scan reports progress.
        @jax_tqdm.scan_tqdm(num_steps)
        def body_fn(svi_state, _):
            if stable_update:
                svi_state, loss = self.stable_update(
                    svi_state,
                    *args,
                    forward_mode_differentiation=forward_mode_differentiation,
                    **kwargs,
                )
            else:
                svi_state, loss = self.update(
                    svi_state,
                    *args,
                    forward_mode_differentiation=forward_mode_differentiation,
                    **kwargs,
                )
            return svi_state, loss

        # Run the full optimization loop for each chain, one progress bar per chain.
        @jax.vmap
        def map_func(i, init_value):
            init_bar = jax_tqdm.PBar(id=i, carry=init_value)
            final_state, losses = jax.lax.scan(body_fn, init_bar, jax.numpy.arange(num_steps))
            return final_state.carry, losses

        # Initialize one SVI state per key; args and kwargs are shared across chains.
        @partial(jax.vmap, in_axes=(0, None, 0, None))
        def vmap_init(rng_key, args, init_params, kwargs):
            return self.init(rng_key, *args, init_params=init_params, **kwargs)

        rng_keys = jax.random.split(rng_key, num_chains)
        if init_states is None:
            svi_states = vmap_init(rng_keys, args, init_params, kwargs)
        else:
            svi_states = init_states
        svi_states, losses = map_func(jax.numpy.arange(num_chains), svi_states)
        return numpyro.infer.svi.SVIRunResult(self.get_params(svi_states), svi_states, losses)
This works fine for one run, but if I try to run it a second time (either passing in new args and kwargs or running it with the exact same inputs) I get the following error:
UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type float64[202] wrapped in a DynamicJaxprTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
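For anyone unfamiliar with this error, here is a minimal plain-JAX sketch of the general failure mode the message describes: a value saved as a side effect during tracing is later used outside the transformation. This is not NumPyro's code, and the names `leaked` and `side_effecting` are made up for illustration. It only shows the class of error, not where it originates in my case.

```python
import jax
import jax.numpy as jnp

leaked = []  # global state that outlives the traced function


@jax.jit
def side_effecting(x):
    y = x + 1.0
    leaked.append(y)  # side effect: stores a tracer, not a concrete array
    return y


side_effecting(jnp.array(1.0))  # the first call traces and runs without complaint

try:
    leaked[0] + 1.0  # reusing the escaped tracer outside the transformation
    raised = False
except jax.errors.UnexpectedTracerError:
    raised = True

print("UnexpectedTracerError raised:", raised)
```

The first call succeeds because the tracer is still valid while the function is being traced. The error only appears when the stored value is touched again later, which matches the "works once, fails the second time" pattern.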
I have tracked this down to how SVI's init method uses the numpyro.handlers.seed function. This issue says you can get around the problem by putting the seed wrapper inside a closure, but I can't seem to get this to work. Everything I have tried has either produced the same error message or forced all the "chains" to initialize at the same location.
If I make a new instance of the SVI class and guide, it runs as expected again, but this triggers a re-compile that in my case takes about 3 times longer than the actual fit (30 seconds compile time, 10 seconds fit time). Ideally I would like to pay this compile cost once up front rather than for each new set of data.
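To make the "compile once" goal concrete, here is a small plain-JAX sketch of the behavior I am hoping to get: a jitted function is traced once per input shape and dtype, so a new dataset of the same shape reuses the compiled code. The function `fit_step` is a hypothetical stand-in for an optimization step, not anything from NumPyro, and the counter only exists to make the tracing visible.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # incremented only while JAX traces the function


@jax.jit
def fit_step(params, data):
    # Hypothetical stand-in for one optimization step; not NumPyro code.
    global trace_count
    trace_count += 1  # Python side effect: runs at trace time, not on cached calls
    return params - 0.1 * (params - jnp.mean(data))


params = jnp.zeros(())
data_a = jnp.arange(5.0)
data_b = jnp.arange(5.0) * 2.0  # same shape and dtype as data_a
data_c = jnp.arange(7.0)        # different shape

fit_step(params, data_a)
fit_step(params, data_b)
traces_same_shape = trace_count  # second dataset reused the first trace
fit_step(params, data_c)
traces_new_shape = trace_count   # a new shape forces a new trace

print(traces_same_shape, traces_new_shape)
```

The data is passed as an argument rather than closed over, which is what lets the cached compilation be reused across datasets.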
To make it easier, here is a self-contained reproduction of the issue I am seeing:
import jax
import numpyro
import numpyro.distributions as dist
import numpyro.infer as infer
from jax import random
from functools import partial


def model():
    a = numpyro.sample('a', dist.Normal(0, 1))
    b = numpyro.sample('b', dist.Normal(0, 1))


keys = random.split(random.PRNGKey(0), 2)
optimizer = numpyro.optim.Adam(step_size=.01)
guide = infer.autoguide.AutoDiagonalNormal(model)
svi = infer.SVI(model, guide, optimizer, loss=infer.TraceEnum_ELBO())


# Initialize one SVI state per key; args and kwargs are shared across keys.
@partial(jax.vmap, in_axes=(0, None, None))
def vmap_init(rng_key, args, kwargs):
    return svi.init(rng_key, *args, **kwargs)


init_state = vmap_init(keys, (), {})  # runs the first time
init_state = vmap_init(keys, (), {})  # fails the second time