SVI on sharded array of arbitrary shape

I want to fit a model on a sharded array so that I can use multiple GPUs. It works fine when the sample size is evenly divisible by the number of GPUs, but what can I do to fit the model when it isn't? Is there a handler to mask the elementwise ELBO loss?

My current solution is to pad the array with 0s (or any meaningful number) so that its size is evenly divisible. After initialization, I replace the padded values with NaN. Then I compute the ELBO loss with jnp.nansum so that the NaN entries are ignored.
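The padding step described above can be sketched in JAX as follows. This is a minimal sketch assuming the batch is split along axis 0; `pad_to_multiple` and `shard_batch` are hypothetical helpers, not part of JAX or NumPyro. Instead of overwriting the padded entries with NaN, it keeps a boolean mask that is `True` for real rows, which can later be used to exclude the padding from the loss.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def pad_to_multiple(x, n_devices, pad_value=0.0):
    """Hypothetical helper: pad axis 0 of `x` so its length is divisible by `n_devices`.

    Returns the padded array and a boolean mask that is True for real rows.
    """
    n = x.shape[0]
    n_pad = (-n) % n_devices  # extra rows needed to reach the next multiple
    pad_width = [(0, n_pad)] + [(0, 0)] * (x.ndim - 1)
    x_padded = jnp.pad(x, pad_width, constant_values=pad_value)
    mask = jnp.arange(n + n_pad) < n
    return x_padded, mask


def shard_batch(x):
    """Hypothetical helper: pad `x` to the device count, then shard axis 0 across all devices."""
    n_devices = jax.device_count()
    x_padded, mask = pad_to_multiple(x, n_devices)
    mesh = Mesh(np.array(jax.devices()), ("batch",))
    sharding = NamedSharding(mesh, P("batch"))  # split axis 0, replicate the rest
    return jax.device_put(x_padded, sharding), jax.device_put(mask, sharding)


x_padded, mask = pad_to_multiple(jnp.arange(10.0), 4)
print(x_padded.shape, int(mask.sum()))  # (12,) 10
```

On a single-device machine `shard_batch` pads by zero rows, so the same code can be tested locally before running on multiple GPUs.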

Any better idea?

I think padding is necessary. Have you tried using the mask handler?

Yes, I have. I realized that with mask I didn't need to customize the ELBO loss, and it works as expected. Thank you!