Does svi.run() work with sharded arrays?

Does svi.run() support training with sharded arrays?

I’ve tried both sharded and non-sharded input arrays during training but didn’t observe any noticeable difference in training time.

Or do I need to write the training loop manually with svi.init() and svi.update() to benefit from sharding?
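For reference, here's roughly what I've been trying (a sketch with toy arrays; `svi` stands in for an SVI object built the usual way elsewhere):

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# toy stand-ins for the real training data
features = jnp.ones((1024, 8))
outcome = jnp.ones((1024,), dtype=jnp.int32)

# shard the batch dimension across all available devices
mesh = Mesh(np.array(jax.devices()), ("batch",))
batch_sharding = NamedSharding(mesh, P("batch"))
features = jax.device_put(features, batch_sharding)
outcome = jax.device_put(outcome, batch_sharding)

# option 1: pass the sharded arrays straight to svi.run
# svi_result = svi.run(jax.random.PRNGKey(0), 2_000, features, outcome)

# option 2: the manual loop with init() and update()
# svi_state = svi.init(jax.random.PRNGKey(0), features, outcome)
# for _ in range(2_000):
#     svi_state, loss = svi.update(svi_state, features, outcome)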

For MCMC, the documentation explicitly mentions sharding across multiple devices: Markov Chain Monte Carlo (MCMC) — NumPyro documentation

I’ve faced this exact issue in my workstream!

I had to distribute training over N GPUs. The way I got it to work was to define a function inside the model itself, wrap it with shard_map, and call it on the sharded arguments.

For example (a sketch: the priors and the linear predictor inside calculate_demand are placeholders I've filled in for illustration):



import jax
import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

def model(features, outcome):
    # one-dimensional device mesh over the batch axis
    mesh = Mesh(np.array(jax.devices()), ("batch",))
    in_specs = (
        P(),           # coefficients: replicated across GPUs
        P("batch"),    # data or indexes: sharded along the batch axis
    )
    out_specs = P("batch")

    # placeholder priors, just for illustration
    coefficients = numpyro.sample(
        "coefficients",
        dist.Normal(0.0, 1.0).expand([features.shape[1]]).to_event(1),
    )
    concentration = numpyro.sample("concentration", dist.HalfNormal(1.0))

    def calculate_demand(coefficients, features):
        # runs on each device with that device's shard of `features`;
        # a placeholder linear predictor keeps the mean positive
        return jnp.exp(features @ coefficients)

    shard_calc = shard_map(
        calculate_demand,
        mesh=mesh,
        in_specs=in_specs,
        out_specs=out_specs,
    )
    demand = shard_calc(coefficients, features)

    with numpyro.plate("data", outcome.shape[0]):
        numpyro.sample(
            "obs",
            dist.NegativeBinomial2(mean=demand, concentration=concentration),
            obs=outcome,
        )

This should work.
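On the SVI side, you then shard the data over a matching mesh before passing it in. A minimal sketch of driving the model above (the AutoNormal guide, Adam settings, and toy arrays are just illustrative):

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoNormal
from numpyro.optim import Adam

# toy data; the batch size must divide evenly across the devices
features = jnp.ones((1024, 8))
outcome = jnp.ones((1024,), dtype=jnp.int32)

mesh = Mesh(np.array(jax.devices()), ("batch",))
batch_sharding = NamedSharding(mesh, P("batch"))
features = jax.device_put(features, batch_sharding)
outcome = jax.device_put(outcome, batch_sharding)

svi = SVI(model, AutoNormal(model), Adam(0.01), Trace_ELBO())
svi_result = svi.run(jax.random.PRNGKey(0), 2_000, features, outcome)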