Re-initializing vmap'ed SVI

The problem I am trying to solve is the following:

I have many datasets I would like to run SVI against, and for each dataset I want to run multiple SVI instances with different random seeds. The model and the guide are the same for every dataset, so I would like to reuse them to avoid recompiling.

To handle running one dataset across multiple random seeds (with a nice progress bar for each seed), I created the following SVI subclass:

import jax_tqdm
import jax
import numpyro
from functools import partial


class SVI_vec(numpyro.infer.SVI):
    def run(
        self,
        rng_key,
        num_chains,
        num_steps,
        *args,
        stable_update=False,
        forward_mode_differentiation=False,
        init_states=None,
        init_params=None,
        **kwargs
    ):
        # One SVI step; scan_tqdm advances the progress bar as lax.scan iterates
        @jax_tqdm.scan_tqdm(num_steps)
        def body_fn(svi_state, _):
            if stable_update:
                svi_state, loss = self.stable_update(
                    svi_state,
                    *args,
                    forward_mode_differentiation=forward_mode_differentiation,
                    **kwargs,
                )
            else:
                svi_state, loss = self.update(
                    svi_state,
                    *args,
                    forward_mode_differentiation=forward_mode_differentiation,
                    **kwargs,
                )
            return svi_state, loss

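        # Run the scanned optimization independently for each chain; the PBar id
        # lets jax_tqdm draw a separate progress bar per chain
        @jax.vmap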
        def map_func(i, init_value):
            init_bar = jax_tqdm.PBar(id=i, carry=init_value)
            final_state, losses = jax.lax.scan(body_fn, init_bar, jax.numpy.arange(num_steps))
            return final_state.carry, losses

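        # Initialize one SVI state per key: rng_key and init_params are mapped over
        # their leading axis, while args and kwargs are shared across chains
        @partial(jax.vmap, in_axes=(0, None, 0, None))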
        def vmap_init(rng_key, args, init_params, kwargs):
            return self.init(rng_key, *args, init_params=init_params, **kwargs)

        rng_keys = jax.random.split(rng_key, num_chains)
        if init_states is None:
            svi_states = vmap_init(rng_keys, args, init_params, kwargs)
        else:
            svi_states = init_states

        svi_states, losses = map_func(jax.numpy.arange(num_chains), svi_states)
        return numpyro.infer.svi.SVIRunResult(self.get_params(svi_states), svi_states, losses)

This works fine for one run, but if I try to run it a second time (either passing in new args and kwargs or running it with exactly the same inputs), I get the following error:

UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type float64[202] wrapped in a DynamicJaxprTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
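The error describes a generic JAX failure mode, not something specific to NumPyro. As an illustration only, the sketch below reproduces it in plain JAX with a hypothetical `CachingInit` class (not NumPyro code): an object that caches a value computed on its first call works inside the first `vmap` trace, but on the second call it hands back a tracer that belongs to the already finished trace.

```python
import jax
import jax.numpy as jnp


class CachingInit:
    """Hypothetical object that lazily caches state on its first call."""

    def __init__(self):
        self.cached = None

    def __call__(self, x):
        if self.cached is None:
            # Side effect: a value computed inside the transformation is stored
            # on the Python object and outlives the trace that created it.
            self.cached = 2.0 * x
        return self.cached + x


init = CachingInit()
batched_init = jax.vmap(init)
xs = jnp.arange(3.0)

first = batched_init(xs)  # works: the cached value belongs to this trace

try:
    batched_init(xs)  # vmap re-traces and reuses the stale tracer from the first call
    leaked = False
except jax.errors.UnexpectedTracerError:
    leaked = True

print(first, leaked)
```

Any object that stores intermediate values on itself during a transformed call and reuses them on a later call can trigger this error.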

I have tracked this down to how SVI’s init method uses the numpyro.handlers.seed function. This issue says you can get around this by putting the seed wrapper inside a closure, but I can’t get this to work. The various things I have tried have either produced the same error message or forced all the “chains” to initialize at the same location.

If I make a new instance of the SVI class and the guide, it runs as expected again, but this triggers a recompile that in my case takes about 3 times longer than the actual fit (30 seconds to compile, 10 seconds to fit). Ideally I would like to pay this compile cost once up front rather than for each new set of data.
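For context, `jax.jit` caches the compiled program keyed on the function and the shapes/dtypes of its inputs, so the compile cost is paid once as long as the same jitted function is reused and every dataset has the same shapes. The toy `fit_step` below is a hypothetical stand-in for an SVI step (plain JAX, not NumPyro); its Python counter only increments while JAX traces the function, not when a cached compilation is reused.

```python
import jax
import jax.numpy as jnp

trace_count = 0


@jax.jit
def fit_step(params, data):
    """Hypothetical stand-in for one optimization step on a dataset."""
    global trace_count
    trace_count += 1  # Python side effect: only runs while JAX is tracing
    return params - 0.1 * jnp.mean(params - data)


params = jnp.zeros(())
for seed in range(3):
    # A new dataset with the same shape and dtype each time
    data = jax.random.normal(jax.random.PRNGKey(seed), (100,))
    params = fit_step(params, data)

print(trace_count)
```

The data is passed as an argument rather than closed over: a new Python function (or new objects baked in as constants) per dataset is what forces JAX to trace and compile again.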

To make it easier, here is a self-contained reproduction of the issue I am seeing:

import jax
import numpyro
import numpyro.distributions as dist
import numpyro.infer as infer
from jax import random
from functools import partial


def model():
    a = numpyro.sample('a', dist.Normal(0, 1))
    b = numpyro.sample('b', dist.Normal(0, 1))


keys = random.split(random.PRNGKey(0), 2)
optimizer = numpyro.optim.Adam(step_size=.01)
guide = infer.autoguide.AutoDiagonalNormal(model)
svi = infer.SVI(model, guide, optimizer, loss=infer.TraceEnum_ELBO())


@partial(jax.vmap, in_axes=(0, None, None))
def vmap_init(rng_key, args, kwargs):
    return svi.init(rng_key, *args, **kwargs)


init_state = vmap_init(keys, (), {}) # runs the first time
init_state = vmap_init(keys, (), {}) # fails the second time