Distribute Parameters across GPUs

As the title asks: is there a way to distribute a model's parameters across multiple GPUs, so that a model too large for a single GPU's memory can still fit? Data can be sharded across devices using JAX's sharding API, but can parameters be sharded the same way?
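To make the question concrete, here is a minimal sketch of what I imagine parameter sharding might look like, using `jax.sharding.NamedSharding` and `jax.device_put`. The mesh axis name `"model"`, the toy layer sizes, and the `forward` function are all assumptions made up for illustration, not taken from any particular model or library recipe.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 1-D mesh over every available device; "model" is a hypothetical axis name.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("model",))
n_dev = devices.size

# Toy layer sizes (assumed). d_out is chosen to be divisible by the device count.
d_in, d_out = 8, 4 * n_dev

key = jax.random.PRNGKey(0)
params = {
    "w": jax.random.normal(key, (d_in, d_out)),
    "b": jnp.zeros((d_out,)),
}

# Split the weight matrix along its output dimension and the bias along its
# only dimension, so each device holds d_out / n_dev columns.
shardings = {
    "w": NamedSharding(mesh, P(None, "model")),
    "b": NamedSharding(mesh, P("model")),
}
sharded_params = jax.device_put(params, shardings)

@jax.jit
def forward(p, x):
    # jit sees the input shardings and partitions the computation accordingly.
    return x @ p["w"] + p["b"]

x = jnp.ones((2, d_in))
y = forward(sharded_params, x)

# Each device should hold only its own slice of the weight matrix.
print(sharded_params["w"].addressable_shards[0].data.shape)
print(y.shape)
```

Note that this sketch first materializes the full parameters on the default device and only then redistributes them, which would not help if the model does not fit on one GPU to begin with. Presumably the initialization itself would have to run under `jax.jit` with `out_shardings` so that each device only ever creates its own slice, but I am not sure whether that is the intended approach.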