Is there a way to distribute a model's parameters across GPUs so that large models can fit in GPU memory? Data can be sharded using JAX's sharding API, but can the parameters be sharded as well?
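For context, here is a minimal sketch of the kind of data sharding the question refers to, assuming a 1D device mesh over all available GPUs (the names `mesh`, `data_sharding`, and the batch shape are illustrative, not from the original post):

```python
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build a 1D mesh over all available devices (e.g. GPUs).
devices = mesh_utils.create_device_mesh((jax.device_count(),))
mesh = Mesh(devices, axis_names=("data",))

# Shard a batch along its leading (batch) axis; the batch size
# should be divisible by the number of devices.
batch = jnp.ones((8, 128))  # hypothetical batch: 8 examples, 128 features
data_sharding = NamedSharding(mesh, P("data", None))
sharded_batch = jax.device_put(batch, data_sharding)

# Each device now holds a slice of the batch.
print(sharded_batch.sharding)
```

The question, then, is whether the model's parameter arrays can be placed across devices in a similar way, rather than replicating the full set of parameters on every GPU.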