Here is the main code:
def init_svi(self, X: DeviceArray, *, lr: float, **kwargs):
    """Build the optimizer and SVI object, creating the SVI state if needed.

    A new optimizer (``self.optim_builder(lr, **kwargs)``) and a new ``SVI``
    object are created on every call and stored on ``self``. An SVI state is
    only stored when ``self.svi_state`` is still ``None``; an existing state
    is kept, so a second call swaps the optimizer settings without resetting
    the state.

    Args:
        X: input data, forwarded to ``self.svi.init``.
        lr: learning rate passed to ``self.optim_builder``.
        kwargs: other keyword arguments forwarded to ``self.optim_builder``.

    Returns:
        ``self``, so calls can be chained.
    """
    optimizer = self.optim_builder(lr, **kwargs)
    self.optim = optimizer
    self.svi = SVI(self.model, self.guide, optimizer, self.loss)
    # NOTE(review): init runs even when its result is discarded below --
    # presumably because SVI.init also prepares the freshly built SVI object
    # for update(); confirm before skipping it.
    initial_state = self.svi.init(self.rng_key, X)
    if self.svi_state is None:
        self.svi_state = initial_state
    return self
def _fit(self, X: DeviceArray, n_epochs: int) -> float:
    """Run ``n_epochs`` SVI update steps on ``X`` and return the scaled loss.

    The loop is compiled with ``jit`` and driven by ``lax.fori_loop``; each
    iteration performs one ``self.svi.update`` call on the whole of ``X``.
    ``self.svi_state`` is replaced with the state after the last step.

    Args:
        X: input data, passed unchanged to ``self.svi.update``. Only
            ``X.shape[0]`` is used directly, to scale the returned loss.
        n_epochs: number of update steps to run. It is an ordinary (traced)
            argument of the jitted loop, not a static one.

    Returns:
        The loss reported by the *last* update step divided by
        ``X.shape[0]`` (0.0 when ``n_epochs`` is 0). It is not an average
        over epochs. NaN/inf values are still returned unchanged.

    Warns:
        RuntimeWarning: if the returned loss is NaN or infinite, so that a
        diverged run is not silently reported as a number.
    """
    import math
    import warnings

    @jit
    def train_epochs(svi_state, data, n_epochs):
        def train_one_epoch(_, carry):
            # Only the state is threaded through; the previous loss is
            # overwritten, so the final carry holds the last step's loss.
            _, state = carry
            state, loss = self.svi.update(state, data)
            return loss, state

        return lax.fori_loop(0, n_epochs, train_one_epoch, (0.0, svi_state))

    # X is passed as an argument instead of being closed over, so jit does
    # not capture the whole dataset as a compile-time constant.
    # NOTE(review): train_epochs is a new function object on every call, so
    # jit re-traces and recompiles it each time _fit runs.
    loss, self.svi_state = train_epochs(self.svi_state, X, n_epochs)
    scaled_loss = float(loss / X.shape[0])
    if not math.isfinite(scaled_loss):
        warnings.warn(
            f"SVI loss is {scaled_loss} after {n_epochs} epoch(s); training "
            "has likely diverged or hit a numerical error. Try a smaller "
            "learning rate, or enable "
            "jax.config.update('jax_debug_nans', True) to find the first "
            "operation that produces a NaN.",
            RuntimeWarning,
            stacklevel=2,
        )
    return scaled_loss
The code is mainly taken from https://github.com/FlorianWilhelm/bhm-at-scale/.
I get NaN values while training.
I tried to debug it, but all the values I get look like this:
How do I debug it?