I am trying to generate a Bayesian Neural Net with Pyro, and feel as though I’m getting close. I can train and predict on a simple test dataset, however after training on my actual data ( aprox 1000 features) I cannot run through the Predictive model due to memory limitations. Even if I take the model off the GPU and attempt using the cpu the process eats up more than 125GiB of ram and will crash.
I suspect I am probably doing something very wrong in the model/guide architecture which might be causing this. Here is what I have, and would appreciate anyone pointing out where I’ve gone wrong…
#
# A basic Linear Layer for Pyro, setup to try and learn the loc, and scale of Normally distributed weights and bias
#
class PyroLinear(torch.nn.Linear, pyro.nn.PyroModule):
def __init__(self, in_features, out_features, device='cpu', **kwargs):
super().__init__(in_features, out_features, **kwargs)
mu = torch.randn_like( self.weight, device=device)
sigma = torch.rand_like( self.weight, device=device) + 0.01
self.weight = pyro.nn.PyroSample( pyro.distributions.Normal( mu, sigma).expand([self.out_features, self.in_features]).to_event(2))
if self.bias is not None:
mu = torch.randn_like( self.bias, device=device)
sigma = torch.rand_like( self.bias, device=device) + 0.01
self.bias = pyro.nn.PyroSample( pyro.distributions.Normal( mu, sigma).expand([self.out_features]).to_event(1))
@property
def device(self):
if len( list(self.parameters())) > 0:
return next(self.parameters()).device
else:
return 'cpu'
#
# Helper function to create a PyroLinear Layer with activation function
#
def layer_block( n_in, n_out, device='cpu'):
return pyro.nn.PyroModule[ torch.nn.Sequential](
PyroLinear(n_in, n_out, device=device),
torch.nn.Tanh()
)
#
# My full model architecture with model/guide
#
class NNet_Model(pyro.nn.PyroModule):
def __init__(self, n_inputs=1, h_layers=[20], scale=0.1, device='cpu'):
super().__init__()
self.predictive = None
self.norm_scale = scale
self.layer_sizes = [n_inputs, *h_layers]
layer_blocks = [ layer_block(in_f, out_f, device=device) for in_f, out_f in zip(self.layer_sizes[:-1], self.layer_sizes[1:])]
self.feature_net = pyro.nn.PyroModule[ torch.nn.Sequential](*layer_blocks)
self.out = PyroLinear( self.layer_sizes[ -1], 2, device=device)
#
# The latent model function used to train with SVI
#
def model_base( self, x, y=None):
x = self.feature_net( x)
mu = self.out( x)
prob = torch.softmax( mu, axis=1)
pred = (prob[:,1] - prob[:,0]) * 0.5 + 0.5
with pyro.plate("data", x.shape[0]):
obs = pyro.sample( "obs", pyro.distributions.Normal( loc=pred, scale=torch.tensor(self.norm_scale).to( self.device)), obs=y).type( torch.float32)
#
# The guide function used in training with SVI
#
def guide(self, x, y=None):
for i in range(0, len( self.feature_net)):
mu = torch.randn_like( self.feature_net[i][0].weight)
sigma = torch.rand_like( self.feature_net[i][0].weight) + 0.1
mu_param = pyro.param( f"feature_net.{i}.0.w_mu", mu)
sigma_param = pyro.param( f"feature_net.{i}.0.w_sigma", sigma, constraint=pyro.distributions.constraints.positive)
_ = pyro.sample( f"feature_net.{i}.0.weight", pyro.distributions.Normal(mu_param, sigma_param).expand([self.feature_net[i][0].weight.size(0), self.feature_net[i][0].weight.size(1)]).to_event(2))
mu = torch.randn_like( self.feature_net[i][0].bias)
sigma = torch.rand_like( self.feature_net[i][0].bias) + 0.1
mu_param = pyro.param( f"feature_net.{i}.0.b_mu", mu)
sigma_param = pyro.param( f"feature_net.{i}.0.b_sigma", sigma, constraint=pyro.distributions.constraints.positive)
_ = pyro.sample( f"feature_net.{i}.0.bias", pyro.distributions.Normal(mu_param, sigma_param).expand([self.feature_net[i][0].bias.size(0)]).to_event(1))
mu = torch.randn_like( self.out.weight)
sigma = torch.rand_like( self.out.weight) + 0.1
mu_param = pyro.param( f"out.w_mu", mu)
sigma_param = pyro.param( f"out.w_sigma", sigma, constraint=pyro.distributions.constraints.positive)
_ = pyro.sample( f"out.weight", pyro.distributions.Normal(mu_param, sigma_param).expand([self.out.weight.size(0), self.out.weight.size(1)]).to_event(2))
mu = torch.randn_like( self.out.bias)
sigma = torch.rand_like( self.out.bias) + 0.1
mu_param = pyro.param( f"out.b_mu", mu)
sigma_param = pyro.param( f"out.b_sigma", sigma, constraint=pyro.distributions.constraints.positive)
_ = pyro.sample( f"out.bias", pyro.distributions.Normal(mu_param, sigma_param).expand([self.out.bias.size(0)]).to_event(1))
# Return the predictions along with the std
# - This creates the Predictive model if it does not yet exist.
# - It then assembles the final mean prediction from "obs" and returns it along with the std
def forward(self, x, samples=100):
if self.predictive == None:
self.predictive = pyro.infer.Predictive(self.model_base, guide=self.guide, num_samples=samples)
result = self.predictive( x)
y_pred = result[ 'obs'].mean( axis=0)
y_std = result[ 'obs'].std( axis=0)
return y_pred.view( x.shape[0], -1), y_std.view( x.shape[0], -1)
@property
def device(self):
return self.out.weight.device
def to( self, device):
super().to( device)
for i in range(0, len( self.feature_net)):
self.feature_net[i][0].weight = self.feature_net[i][0].weight.to( device)
self.feature_net[i][0].bias = self.feature_net[i][0].bias.to( device)
self.out.weight = self.out.weight.to( device)
self.out.bias = self.out.bias.to( device)
I am also rather confused as to why this seems to work (on a small dataset) even though my model_base()
and guide()
do not actually return anything. I guess this is where Pyro does all it’s work under the hood, which makes it difficult to follow what is actually going on.