Making use of GPU Acceleration

Hi everybody,

I was wondering if anyone has advice on building models that run efficiently on GPUs?

I have a model where I replaced the for loops with jax.lax.scan, but I find that it runs far faster (more than 5x) on my CPU than on my GPU. The model performs vector computations on arrays roughly of size 50x100. Perhaps this is too small to fully make use of the GPU?
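
For reference, here's a stripped-down sketch of the pattern I mean (the `step` function is a made-up toy, not my actual model):

```python
import jax
import jax.numpy as jnp

def step(carry, x):
    # One iteration of the original for loop: update a running
    # scalar state using the current input row (toy update).
    new_carry = carry + jnp.dot(x, x) * 0.01
    return new_carry, new_carry

# xs stands in for my data: arrays roughly of size 50x100.
xs = jnp.ones((50, 100))
init = jnp.zeros(())

# jax.lax.scan replaces the Python for loop, so the whole loop
# is compiled as a single XLA op instead of being unrolled.
final_carry, history = jax.lax.scan(step, init, xs)
```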

Any advice on this would be much appreciated, along with any other tips on making my model run as efficiently as possible!

Thanks

Not knowing anything about your problem: yes, tensors of size 50x100 are likely too small to see benefits from GPU acceleration.
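
A quick way to check is to time the same jitted op on both devices at a couple of sizes. Something like this rough sketch (it assumes a GPU is visible to JAX, and the exact numbers will depend on your hardware):

```python
import time
import jax
import jax.numpy as jnp

def bench(n, m, device):
    # Time a jitted matmul on the given device; block_until_ready
    # matters because JAX dispatches work asynchronously.
    x = jax.device_put(jnp.ones((n, m)), device)
    f = jax.jit(lambda a: a @ a.T)
    f(x).block_until_ready()  # warm up / trigger compilation
    t0 = time.perf_counter()
    for _ in range(100):
        f(x).block_until_ready()
    return (time.perf_counter() - t0) / 100

cpu = jax.devices("cpu")[0]
gpu = jax.devices("gpu")[0]  # assumes a GPU is available
for n, m in [(50, 100), (5000, 1000)]:
    print(n, m, "cpu:", bench(n, m, cpu), "gpu:", bench(n, m, gpu))
```

At 50x100 the per-call overhead of launching GPU kernels and moving data usually dominates; at sizes in the thousands the GPU typically pulls ahead. If you can batch many of these small computations together (e.g. with jax.vmap), the GPU gets enough work per kernel to be worthwhile.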