NumPyro support for Apple Silicon GPU?

Are there any plans for NumPyro to support Apple Silicon GPUs? I think it would be a very popular feature.

JAX does have (experimental) support for Apple GPUs; see the "Installing JAX" page in the JAX documentation.
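Before trying NumPyro on the GPU, it can help to check which backend JAX actually picked up. The snippet below is a minimal sketch using only standard JAX calls (`jax.default_backend`, `jax.devices`, `jax.jit`). The helper names `describe_backend` and `matmul_sum` are hypothetical, for illustration only. On a machine without a working GPU plugin it simply reports the CPU backend, and the exact platform name reported when Apple's Metal plugin is active is not shown here.

```python
import jax
import jax.numpy as jnp


def describe_backend():
    """Hypothetical helper: return JAX's default backend name and visible devices."""
    return jax.default_backend(), jax.devices()


@jax.jit
def matmul_sum(x):
    # Small jitted computation; JAX runs it on the default device.
    return jnp.sum(x @ x.T)


if __name__ == "__main__":
    backend, devices = describe_backend()
    print("default backend:", backend)
    print("devices:", devices)
    # ones((4, 3)) @ ones((3, 4)) is a 4x4 matrix of 3.0, so the sum is 48.0
    print(matmul_sum(jnp.ones((4, 3))))
```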

Also see the Apple docs here, which aren’t 100% up to date; there’s a JAX issue open on that.

Hi @benjamv, if JAX supports Apple Silicon GPUs, it’s likely that NumPyro code will work.