Hi all,
I’ve tried to run a Sigmoid Belief Network (Neal, 1992) based on the Sparse Gamma Deep Exponential Families example. However, after about 2,600 iterations my PC runs out of memory with the following error:
Traceback (most recent call last):
File "C:\Program Files (x86)\JetBrains\PyCharm 2018.2.4\helpers\pydev\pydevd.py", line 1664, in <module>
main()
File "C:\Program Files (x86)\JetBrains\PyCharm 2018.2.4\helpers\pydev\pydevd.py", line 1658, in main
globals = debugger.run(setup['file'], None, None, is_module)
File "C:\Program Files (x86)\JetBrains\PyCharm 2018.2.4\helpers\pydev\pydevd.py", line 1068, in run
pydev_imports.execfile(file, globals, locals) # execute the script
File "C:\Program Files (x86)\JetBrains\PyCharm 2018.2.4\helpers\pydev\_pydev_imps\_pydev_execfile.py", line 18, in execfile
exec(compile(contents+"\n", file, 'exec'), glob, loc)
File "C:/Users/posc8001/Documents/DEF/Scipio_DEF/sigmoid_belief_network.py", line 370, in <module>
model = main(args)
File "C:/Users/posc8001/Documents/DEF/Scipio_DEF/sigmoid_belief_network.py", line 216, in main
loss = svi.step(data)
File "C:\Users\posc8001\.virtualenvs\Scipio_DEF-F7b0vflQ\lib\site-packages\pyro\infer\svi.py", line 99, in step
loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
File "C:\Users\posc8001\.virtualenvs\Scipio_DEF-F7b0vflQ\lib\site-packages\pyro\infer\tracegraph_elbo.py", line 225, in loss_and_grads
loss += self._loss_and_grads_particle(weight, model_trace, guide_trace)
File "C:\Users\posc8001\.virtualenvs\Scipio_DEF-F7b0vflQ\lib\site-packages\pyro\infer\tracegraph_elbo.py", line 248, in _loss_and_grads_particle
torch_backward(weight * (surrogate_loss + baseline_loss), retain_graph=self.retain_graph)
File "C:\Users\posc8001\.virtualenvs\Scipio_DEF-F7b0vflQ\lib\site-packages\pyro\infer\util.py", line 43, in torch_backward
x.backward(retain_graph=retain_graph)
File "C:\Users\posc8001\.virtualenvs\Scipio_DEF-F7b0vflQ\lib\site-packages\torch\tensor.py", line 102, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph)
File "C:\Users\posc8001\.virtualenvs\Scipio_DEF-F7b0vflQ\lib\site-packages\torch\autograd\__init__.py", line 90, in backward
allow_unreachable=True) # allow_unreachable flag
RuntimeError: $ Torch: not enough memory: you tried to allocate 0GB. Buy new RAM! at ..\aten\src\TH\THGeneral.cpp:201
Now I’m wondering whether there is any way to clear memory between SVI steps. I couldn’t find an existing solution on the forum, and gc.collect() doesn’t do the trick. I have already detached any parameters and loss values that I store to monitor their behavior, so those shouldn’t be keeping the computation graph alive.
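For reference, this is roughly what that part of my training loop looks like after those attempts (a simplified sketch of the full code below; stored_w_top is just an illustrative name for one of the history lists I keep, and svi and data are constructed as in main() below):

import gc
import pyro

num_epochs = 10000  # as in the full script below
stored_w_top = []   # illustrative name for one of my history lists
for k in range(num_epochs):
    loss = svi.step(data)  # returns a plain Python float
    # detach (and copy) before storing, so no autograd graph should stay alive
    stored_w_top.append(pyro.param("mean_w_q_top").detach().clone())
    if k % 10 == 0:
        gc.collect()  # makes no noticeable difference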
Any ideas what might help?
Best,
Scipio
Full reference code:
import os
import sys
import argparse
import numpy as np
import torch
from pathlib import Path
from matplotlib import pyplot as plt
import pandas as pd
import torch.utils.data
import torch.optim as torchoptim
import xlsxwriter
import gc
import pyro
from pyro import poutine
import pyro.optim as pyrooptim
from pyro.distributions import Bernoulli, Normal, RelaxedBernoulliStraightThrough
from pyro.contrib.autoguide import AutoDiagonalNormal, AutoGuideList, AutoDiscreteParallel
from pyro.infer import SVI, TraceGraph_ELBO, TraceEnum_ELBO
torch.set_default_tensor_type('torch.FloatTensor')
pyro.enable_validation(True)
pyro.clear_param_store()

# flag referenced in guide() and main(): if True, one set of z parameters is
# shared across all data points; if False, each data point gets its own
# (it was undefined in my original paste; False is the setting I run with)
twoParams = False
def sigmoid(x):
    return 1 / (1 + np.exp(-x))
class SigmoidBeliefDEF(object):
    def __init__(self):
        # define the sizes of the layers in the deep exponential family
        self.top_width = 2
        self.bottom_width = 3
        self.data_size = 5
        # define hyperparameters that control the prior
        self.p_z = torch.tensor(0.0)
        self.mu_w = torch.tensor(0.0)
        self.sigma_w = torch.tensor(3.0)
        # define parameters used to initialize variational parameters
        self.z_mean_init = 0.0
        self.z_sigma_init = 3.0
        self.w_mean_init = 0.0
        self.w_sigma_init = 5.0
        self.softplus = torch.nn.Softplus()
    # 1
    # define the model
    def model(self, x):
        x_size = x.size(0)

        # 1.1
        # sample the global weights
        with pyro.plate("w_top_plate", self.top_width * self.bottom_width):
            w_top = pyro.sample("w_top", Normal(self.mu_w, self.sigma_w))
        with pyro.plate("w_bottom_plate", self.bottom_width * self.data_size):
            w_bottom = pyro.sample("w_bottom", Normal(self.mu_w, self.sigma_w))

        # 1.2
        # sample the local latent random variables
        # (the plate encodes the fact that the z's for different data points are conditionally independent)
        with pyro.plate("data", size=x_size):
            z_top = pyro.sample("z_top", Bernoulli(torch.sigmoid(self.p_z)).expand([self.top_width]).to_event(1))
            # note that we need to use matmul (batch matrix multiplication) as well as appropriate reshaping
            # to make sure our code is fully vectorized
            w_top = w_top.reshape(self.top_width, self.bottom_width) if w_top.dim() == 1 else \
                w_top.reshape(-1, self.top_width, self.bottom_width)
            mean_bottom = torch.sigmoid(torch.matmul(z_top, w_top))
            z_bottom = pyro.sample("z_bottom", Bernoulli(mean_bottom).to_event(1))
            w_bottom = w_bottom.reshape(self.bottom_width, self.data_size) if w_bottom.dim() == 1 else \
                w_bottom.reshape(-1, self.bottom_width, self.data_size)
            mean_obs = torch.sigmoid(torch.matmul(z_bottom, w_bottom))
            # observe the data using a Bernoulli likelihood
            pyro.sample('obs', Bernoulli(mean_obs).to_event(1), obs=x)
    # 2
    # define our custom guide a.k.a. variational distribution.
    # (note the guide is mean field)
    def guide(self, x):
        x_size = x.size(0)

        # helper for initializing variational parameters
        def rand_tensor(shape, mean, sigma):
            return mean * torch.ones(shape) + sigma * torch.randn(shape)

        # 2.1
        # define a helper function to sample z's for a single layer
        def sample_zs(name, width, mean=0):
            # sample parameters: either one vector shared across data points
            # (twoParams) or a separate vector per data point
            if twoParams:
                p_z_q = pyro.param("p_z_q_%s" % name,
                                   lambda: rand_tensor(width, self.p_z, self.z_sigma_init))
                p_z_q = torch.sigmoid(p_z_q).repeat(x_size).reshape(x_size, width)
            else:
                p_z_q = pyro.param("p_z_q_%s" % name,
                                   lambda: rand_tensor((x_size, width), self.p_z, self.z_sigma_init))
                p_z_q = torch.sigmoid(p_z_q)
            # sample the z's (with a decaying-average baseline to reduce gradient variance)
            z = pyro.sample("z_%s" % name, Bernoulli(p_z_q).to_event(1),
                            infer=dict(baseline={'use_decaying_avg_baseline': True}))
            return z

        # define a helper function to sample w's for a single layer
        def sample_ws(name, width):
            # sample parameters
            mean_w_q = pyro.param("mean_w_q_%s" % name,
                                  lambda: rand_tensor(width, self.w_mean_init, self.w_sigma_init))
            sigma_w_q = pyro.param("sigma_w_q_%s" % name,
                                   lambda: rand_tensor(width, self.w_mean_init, self.w_sigma_init))
            sigma_w_q = self.softplus(sigma_w_q)
            # sample the weights
            w = pyro.sample("w_%s" % name, Normal(mean_w_q, sigma_w_q))
            return w

        # sample the global weights
        with pyro.plate("w_top_plate", self.top_width * self.bottom_width):
            sample_ws("top", self.top_width * self.bottom_width)
        with pyro.plate("w_bottom_plate", self.bottom_width * self.data_size):
            sample_ws("bottom", self.bottom_width * self.data_size)
        # sample the local latent random variables
        with pyro.plate("data", x_size):
            sample_zs("top", self.top_width)
            sample_zs("bottom", self.bottom_width)
def main(args):
    dataset_path = Path(r"C:\Users\posc8001\Documents\DEF\Scipio_DEF\data_generation")
    file_to_open = dataset_path / "small_data.csv"
    with open(file_to_open) as f:
        data = torch.tensor(np.loadtxt(f, delimiter=',')).float()

    sigmoid_belief_def = SigmoidBeliefDEF()
    opt = pyrooptim.PyroOptim(torchoptim.Adadelta, {})

    # Specify parameters of sampling process
    n_samp = 20000
    guide = sigmoid_belief_def.guide

    # Specify Stochastic Variational Inference
    svi = SVI(sigmoid_belief_def.model, guide, opt,
              loss=TraceGraph_ELBO(num_particles=args.eval_particles,
                                   vectorize_particles=True))
    if args.store_params:
        # history buffers for losses and parameter values
        losses = []
        final_p_z_0 = [[] for _ in range(2)]
        final_w_top = [[] for _ in range(6)]
        final_sig_w_top = [[] for _ in range(6)]
        final_w_bottom = [[] for _ in range(15)]
        final_sig_w_bottom = [[] for _ in range(15)]
        sample = []
    # the training loop
    for k in range(args.num_epochs):
        if k % 10 == 0:
            gc.collect()
        loss = svi.step(data)

        if args.store_params:
            losses.append(loss)
            for i in range(2):
                if twoParams:
                    final_p_z_0[i].append(torch.sigmoid(pyro.param("p_z_q_top")[i]).item())
                else:
                    final_p_z_0[i].append(torch.sigmoid(pyro.param("p_z_q_top")[:, i].mean()).item())
            for i in range(6):
                final_w_top[i].append(pyro.param("mean_w_q_top")[i].item())
            for i in range(15):
                final_w_bottom[i].append(pyro.param("mean_w_q_bottom")[i].item())

        if (k % args.eval_frequency == 0 and k > 0) or k == args.num_epochs - 1:
            print("[epoch %04d] training elbo: %.4g" % (k, loss))
    # Plot all parameters and losses
    if args.store_params:
        plt.plot(losses)
        plt.title("ELBO")
        plt.xlabel("step")
        plt.ylabel("loss")
        plt.show()
        for i in range(len(final_p_z_0)):
            plt.plot(final_p_z_0[i][19:len(losses) - 1])
            plt.title("P Z_top_" + str(i + 1) + " - Without the first 20 obs")
            plt.show()
        for i in range(len(final_w_top)):
            plt.plot(final_w_top[i][119:len(losses) - 1])
            plt.title("Mean W_top_" + str(i + 1) + " - Without the first 120 obs")
            plt.show()
        for i in range(len(final_w_bottom)):
            plt.plot(final_w_bottom[i][119:len(losses) - 1])
            plt.title("Mean W_bottom_" + str(i + 1) + " - Without the first 120 obs")
            plt.show()

    # return the model object so that `model = main(args)` below receives a value
    return sigmoid_belief_def
if __name__ == '__main__':
    assert pyro.__version__.startswith('0.3.0')
    # parse command line arguments
    parser = argparse.ArgumentParser(description="parse args")
    parser.add_argument('-n', '--num-epochs', default=10000, type=int, help='number of training epochs')
    parser.add_argument('-ef', '--eval-frequency', default=25, type=int,
                        help='how often to evaluate elbo (number of epochs)')
    parser.add_argument('-ep', '--eval-particles', default=200, type=int,
                        help='number of samples/particles to use during evaluation')
    parser.add_argument('--auto-guide', action='store_true', help='whether to use an automatically constructed guide')
    parser.add_argument('--create-samples', action='store_true', help='whether to create samples')
    parser.add_argument('--store-params', action='store_true', help='whether to store historical loss and '
                                                                    'parameter values')
    args = parser.parse_args()
    model = main(args)
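(For completeness: I start the script with, e.g., python sigmoid_belief_network.py --num-epochs 10000 --store-params, where the flags correspond to the argparse definitions above.)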