Is there any advice on building models that run efficiently on GPUs?
I have a model in which I replaced for loops with jax.lax.scan, but I find that it runs more than 5 times faster on my CPU than on my GPU. The model performs vectorized computations on arrays of roughly size 50x100. Is that perhaps too small to make full use of the GPU?
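For context, here is a minimal sketch of the kind of setup and timing I mean. The `step` function is a hypothetical stand-in for my actual model update, and `run_model`, `time_on`, the step count, and the repeat count are illustrative names and values chosen for this example. It times one jitted `jax.lax.scan` call per available device, after a warm-up call so that compilation is not counted.

```python
import time

import jax
import jax.numpy as jnp


def step(carry, x):
    # Hypothetical update: cheap elementwise work on a (50, 100) array.
    new_carry = jnp.tanh(carry * 0.9 + x)
    return new_carry, new_carry.sum()


@jax.jit
def run_model(init, xs):
    # lax.scan replaces the for loop over the leading axis of xs.
    return jax.lax.scan(step, init, xs)


def time_on(device, n_steps=1000, repeats=5):
    key = jax.random.PRNGKey(0)
    # Place the inputs on the chosen device so the jitted call runs there.
    init = jax.device_put(jnp.zeros((50, 100)), device)
    xs = jax.device_put(jax.random.normal(key, (n_steps, 50, 100)), device)
    # Warm-up call: triggers compilation, which should not be timed.
    run_model(init, xs)[1].block_until_ready()
    start = time.perf_counter()
    for _ in range(repeats):
        # JAX dispatches asynchronously, so block before stopping the clock.
        run_model(init, xs)[1].block_until_ready()
    return (time.perf_counter() - start) / repeats


# Always time the CPU; add the default accelerator if one is present.
devices = jax.devices("cpu")[:1]
if jax.default_backend() != "cpu":
    devices += jax.devices()[:1]

timings = {str(d): time_on(d) for d in devices}
for name, seconds in timings.items():
    print(f"{name}: {seconds * 1e3:.2f} ms per call")
```

Each scan step here does very little work, so I would guess the GPU time is dominated by per-step overhead rather than arithmetic, but that is an assumption on my part.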
Any advice on this, or more generally on how to make my model run as efficiently as possible, would be much appreciated!