Distribute Parameters across GPUs

As the title asks: is there a way to distribute a model's parameters across GPUs so that models too large for a single device's memory can still fit? Data can be sharded with JAX's sharding API, but can the parameters be sharded as well?
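
For reference, this is roughly what I mean by data sharding (a minimal sketch; the shapes and the axis name are placeholders):

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 1D mesh over all local devices; "batch" is an arbitrary axis name.
mesh = Mesh(np.asarray(jax.devices()), axis_names=("batch",))

# Shard a data batch along its leading axis, one slice per device.
data = jnp.ones((1024, 128))
data = jax.device_put(data, NamedSharding(mesh, P("batch", None)))
```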

Have you tried the JAX documentation page Distributed arrays and automatic parallelization? I think SVI will work with distributed parameters.
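
Roughly, the idea would be something like the sketch below. This is a minimal, pure-JAX example, not NumPyro-specific; the shapes and the "model" axis name are placeholders, and it assumes the sharded dimension divides evenly across your devices:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 1D device mesh; "model" is an arbitrary name for the parameter axis.
mesh = Mesh(np.asarray(jax.devices()), axis_names=("model",))

# Hypothetical large parameters: shard the rows of W across devices, so
# each device stores only an (n / n_devices)-row slice instead of all of W.
n = 8192  # assumed divisible by the number of devices
W = jax.device_put(jnp.zeros((n, n)), NamedSharding(mesh, P("model", None)))
b = jax.device_put(jnp.zeros((n,)), NamedSharding(mesh, P("model")))

# Under jit, XLA partitions the computation to match the input shardings,
# so each device works on its own slice of W and b.
@jax.jit
def affine(W, b, x):
    return W @ x + b

x = jnp.ones((n,))   # replicated input
y = affine(W, b, x)  # output comes back sharded along "model"
print(y.sharding)
```

If the initial parameter values are already sharded like this, a jit-compiled update step should preserve that sharding, so I'd expect SVI's parameters to stay distributed across devices, but it's worth verifying on your model.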