Parallelising Numpyro

I think you can write a wrapper function like run_inference and use joblib to distribute the tasks across CPUs, one task per CPU.

On each CPU, if your datasets all have the same shape, you can pass jit_model_args=True to MCMC, so that calling run again on a different dataset of the same shape does not trigger recompilation. If your datasets have different shapes and there are no local latent variables (e.g. the model in the previous thread works because it only contains global latent variables), you can pad the data to a common shape and use a mask to exclude the padded part in the model. This way, jit_model_args will still work.

You can optimize further by grouping datasets with similar shapes into the same task, so that each CPU only pads up to the largest shape in its own group. For example, if you have datasets A, B, C, D whose shapes are 7, 5, 3, 9, you can run B and C on one CPU and A and D on another.
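One simple way to form such groups (a heuristic of my own, not necessarily the only or best one) is to sort datasets by size and split the sorted list into contiguous chunks, one per CPU. group_by_size and padding_waste are hypothetical helpers; padding_waste counts how many masked-out entries the padding introduces.

```python
def group_by_size(sizes, n_groups):
    """Sort dataset names by size, then split into n_groups contiguous chunks."""
    names = sorted(sizes, key=sizes.get)
    base, extra = divmod(len(names), n_groups)
    groups, start = [], 0
    for i in range(n_groups):
        end = start + base + (1 if i < extra else 0)
        groups.append(names[start:end])
        start = end
    return groups


def padding_waste(groups, sizes):
    """Padded (masked-out) entries when each group is padded to its own max size."""
    waste = 0
    for group in groups:
        if group:
            longest = max(sizes[n] for n in group)
            waste += sum(longest - sizes[n] for n in group)
    return waste


sizes = {"A": 7, "B": 5, "C": 3, "D": 9}
groups = group_by_size(sizes, n_groups=2)
print(groups)                               # [['C', 'B'], ['A', 'D']]
print(padding_waste(groups, sizes))         # 4
print(padding_waste([list(sizes)], sizes))  # 12
```

This balances the number of datasets per CPU, not the total sampling work, so very uneven sizes may call for a smarter split.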
