I had to distribute training over N GPUs. The way I got it to work is to define a function inside the model itself that takes the sharded arguments, then wrap that function with shard_map so the heavy computation runs on every device against its local shard of the data.
For example:
import jax
import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

def model(features, outcome=None):
    # One mesh axis ("batch") spanning all available GPUs.
    mesh = Mesh(np.array(jax.devices()), ("batch",))
    in_specs = (
        P(),         # coefficients: replicated across GPUs
        P("batch"),  # data or indexes: sharded along the batch axis
    )
    out_specs = P("batch")  # per-device results gathered back along "batch"

    # Global parameters (example priors); these get replicated into every shard.
    coefficients = numpyro.sample(
        "coefficients",
        dist.Normal(0.0, 1.0).expand([features.shape[1]]).to_event(1),
    )
    concentration = numpyro.sample("concentration", dist.Exponential(1.0))

    def calculate_demand(coefficients, features):
        # Runs on each device with its local shard of `features`;
        # the actual demand calculation goes here (placeholder shown).
        return jnp.exp(features @ coefficients)

    shard_calc = shard_map(
        calculate_demand,
        mesh=mesh,
        in_specs=in_specs,
        out_specs=out_specs,
    )
    demand = shard_calc(coefficients, features)

    data_plate = numpyro.plate("data", features.shape[0])
    with data_plate:
        # Sample observations
        numpyro.sample(
            "obs",
            dist.NegativeBinomial2(mean=demand, concentration=concentration),
            obs=outcome,
        )
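
To run it end to end, the arrays should be placed on the same "batch" axis before being passed in. Here's a minimal sketch of the call site, assuming the leading dimension of the data divides evenly across the GPUs; the toy arrays and sampler settings are just placeholders, not my real setup:

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from numpyro.infer import MCMC, NUTS

# Toy data: the leading dimension must divide evenly across the GPUs.
n_dev = len(jax.devices())
features = jnp.ones((1024 * n_dev, 3))
outcome = jnp.ones(1024 * n_dev, dtype=jnp.int32)

# Pre-shard along the same "batch" axis the model's mesh uses, so each
# GPU already holds its slice when shard_map runs.
mesh = Mesh(np.array(jax.devices()), ("batch",))
sharding = NamedSharding(mesh, P("batch"))
features = jax.device_put(features, sharding)
outcome = jax.device_put(outcome, sharding)

mcmc = MCMC(NUTS(model), num_warmup=500, num_samples=1000)
mcmc.run(jax.random.PRNGKey(0), features, outcome=outcome)
mcmc.print_summary()

Pre-placing the data with device_put isn't strictly required, but it saves shard_map from reshuffling the arrays across devices on every model evaluation.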