My model looks something like this; it essentially models a generative time-series process:
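```python
# One plate over the batch dimension; X is time-major (time x batch x features)
with pyro.iarange('X_iarange', X.size(1), use_cuda=X.is_cuda):
    for t in range(T_max):
        # 1.0 for sequences that are still running at step t, 0.0 for padded steps
        T_mask = (t < L).float()
        # Latent transition: new hidden state plus parameters of p(z_t | z_{t-1})
        h_t, z_log_mu, z_log_var = self.transition(z_prev, h_prev)
        z_dist = dist.Normal(z_log_mu.exp(), z_log_var.exp()).mask(T_mask).independent(1)
        z_t = pyro.sample('Z_{}'.format(t + 1), z_dist)
        # Emission: parameters of p(x_t | z_t); only the first 13 features are modelled
        x_log_mu, x_log_var = self.emitter(z_t)
        x_dist = dist.Normal(x_log_mu.exp(), x_log_var.exp()).mask(T_mask).independent(1)
        pyro.sample('X_{}'.format(t + 1), x_dist, obs=X[t, :, :13])
        # Carry the state forward to the next time step
        h_prev = h_t
        z_prev = z_t
```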
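For reference, here is a minimal, self-contained sketch (plain `torch.distributions`, not Pyro) of what the per-step masking is meant to do for variable-length sequences: steps at or beyond a sequence's length contribute zero log-density. It sums over the feature dimension before applying the per-sequence mask, which roughly mirrors calling `.independent(1)` before `.mask(...)`, so that a mask of shape `(batch,)` lines up with the batch dimension. The function name `masked_step_log_prob` is hypothetical, and the sketch assumes `log_var` is a log-variance, so the scale passed to `Normal` is the standard deviation `exp(0.5 * log_var)`.

```python
import torch
from torch.distributions import Normal


def masked_step_log_prob(x_t, loc, log_var, t, lengths):
    """Log-density of time step t for each sequence in a batch (hypothetical helper).

    x_t, loc, log_var: tensors of shape (batch, features)
    lengths: tensor of shape (batch,) with each sequence's true length
    Steps at or beyond a sequence's length contribute 0.
    """
    # 1.0 where the sequence is still running at step t, 0.0 for padding
    t_mask = (t < lengths).float()                      # (batch,)
    # Assumes log_var is a log-variance: standard deviation = exp(0.5 * log_var)
    scale = (0.5 * log_var).exp()
    # Sum over features first (like .independent(1)), giving one value per sequence
    step_lp = Normal(loc, scale).log_prob(x_t).sum(-1)  # (batch,)
    # The per-sequence mask now lines up with the batch dimension
    return step_lp * t_mask


lengths = torch.tensor([3, 1])
zeros = torch.zeros(2, 4)
print(masked_step_log_prob(zeros, zeros, zeros, t=2, lengths=lengths))
```

With `lengths = [3, 1]` and `t=2`, the first sequence gets a real log-density while the second one, which has already ended, contributes zero.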
The training loop looks like this:
```python
loss = 0.0
for (L, X, ihm, los, pheno, decomp) in tqdm(train_data_loader, desc='Minibatch'):
    # Move the NumPy minibatch to the GPU (L: sequence lengths, X: time-major inputs)
    L = torch.from_numpy(L).long().to(args.device)
    X = torch.from_numpy(X).float().to(args.device)
    pheno = torch.from_numpy(pheno).float().to(args.device)
    # svi.step takes one gradient step and returns the loss as a Python float
    minibatch_loss = svi.step(L, X, pheno)
    loss += minibatch_loss
print(loss / len(train_data_loader.dataset))
```
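One way to narrow this down is to log GPU memory after every `svi.step`: steady growth across minibatches would suggest that something keeps references to old tensors alive, while spikes only on some minibatches would point to unusually long sequences (a large `T_max`). A minimal sketch using `torch.cuda.memory_allocated` and `torch.cuda.max_memory_allocated`; the helper name `gpu_mem_mib` is hypothetical.

```python
import torch


def gpu_mem_mib(device=None):
    """Return (currently allocated, peak allocated) GPU memory in MiB (hypothetical helper).

    Falls back to (0.0, 0.0) when CUDA is not available.
    """
    if not torch.cuda.is_available():
        return 0.0, 0.0
    mib = 2 ** 20
    allocated = torch.cuda.memory_allocated(device) / mib
    peak = torch.cuda.max_memory_allocated(device) / mib
    return allocated, peak


# Example: log right after each optimisation step inside the training loop, e.g.
#     minibatch_loss = svi.step(L, X, pheno)
#     print('allocated {:.1f} MiB, peak {:.1f} MiB'.format(*gpu_mem_mib()))
```

Comparing these numbers with the `T_max` of each minibatch should show whether memory tracks sequence length or simply keeps climbing.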
This runs for a few epochs and then crashes with the following error:
```
RuntimeError: CUDA out of memory. Tried to allocate 2.50 MiB (GPU 0; 5.94 GiB total capacity; 5.59 GiB already allocated; 2.06 MiB free; 14.84 MiB cached)
```
I'm running this in a Jupyter notebook for now so I can quickly play around with values. Even after waiting a while, the GPU memory stays allocated, which seems odd (GPU status below). Can somebody please help me debug this in Pyro?
```
Timestamp: Fri Dec 7 17:37:15 2018
Driver Version: 390.77
Number of GPUs: 1
----------------------------------------------------------------------------------------------
#  Name                     Mem. Use           Mem. Use (%)  Pow. Use          Temp.
----------------------------------------------------------------------------------------------
1  GeForce GTX TITAN Black  6081 MiB/6083 MiB  99.97         14.37 W/250.00 W  40 C
No GPUs Driving System Display
----------------------------------------------------------------------------------------------
#  GPU  PID    Process Name          Mem. Use
----------------------------------------------------------------------------------------------
0  1    14592  /path/to/python/file  6069 MiB
```
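On the memory that stays allocated: PyTorch's caching allocator keeps freed blocks reserved for reuse, so a monitoring tool like the one above keeps reporting them until `torch.cuda.empty_cache()` is called or the notebook kernel is restarted. In addition, after an uncaught exception the interpreter (IPython included) keeps the last traceback in `sys.last_traceback`, and its frames can still reference tensors from the failed step. Below is a best-effort cleanup sketch; the function name is hypothetical, and it cannot free tensors that are still referenced from notebook variables. Those need an explicit `del` (for example of `svi` and the current minibatch), and Pyro parameters live in the param store, which `pyro.clear_param_store()` empties.

```python
import gc
import sys

import torch


def release_cached_gpu_memory():
    """Best-effort cleanup after an out-of-memory error in an interactive session.

    Hypothetical helper. Returns the number of unreachable objects found by gc.
    """
    # Drop the interpreter's references to the last uncaught exception; its
    # traceback frames can keep tensors from the failed step alive.
    for name in ("last_type", "last_value", "last_traceback", "last_exc"):
        if hasattr(sys, name):
            setattr(sys, name, None)
    # Collect unreachable objects (e.g. reference cycles that hold tensors)
    collected = gc.collect()
    # Return cached, unused blocks to the driver so other tools see them as free
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return collected
```

Note that `empty_cache()` only releases blocks that no live tensor uses, so it changes what monitoring tools report but does not by itself make more memory available to PyTorch.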