I’m trying to understand the process of training a Bayesian neural network. While the examples have been helpful, they don’t seem to provide the full picture.
My understanding is that I want a network with linear layers whose weights and biases are each Normal distributions, with a mean and scale trained for every connection.
I believe I have created such a network; however, I’m confused about where Pyro takes over the training process. For the `obs` site, don’t I want to draw samples by running my network multiple times? The tutorials seem to indicate creating something like `pyro.sample("obs", pyro.distributions.OneHotCategorical(probs=prob), obs=y)` or `pyro.distributions.Normal()`, etc.
My data has several classes; however, the targets are normally distributed between 0 and 1. So I am trying to train a network that predicts probabilities for the two extremes and collapses the one-hot encoding of the two classes into a single value.
Here is what I have. I suspect the `pyro.plate` and `obs` sampling may need to move, but that’s not what the documentation seems to indicate.
##
## My custom Linear layer
##
class PyroLinear(torch.nn.Linear, pyro.nn.PyroModule):  # used as a mixin
    """``nn.Linear`` whose weight and bias are PyroSample latent variables.

    After construction, ``weight`` (and ``bias``, when present) are no longer
    ``nn.Parameter``s: each is replaced by a ``PyroSample`` holding an
    element-wise Normal prior, so reading the attribute inside a model issues
    a ``pyro.sample`` statement.

    Args:
        in_features: input width, forwarded to ``nn.Linear``.
        out_features: output width, forwarded to ``nn.Linear``.
        device: device on which the prior loc/scale tensors are created.
        **kwargs: forwarded to ``nn.Linear`` (e.g. ``bias=False``).
    """

    def __init__(self, in_features, out_features, device='cpu', **kwargs):
        super().__init__(in_features, out_features, **kwargs)
        # Remembered because, once the parameters are swapped for PyroSample
        # priors below, the device can no longer be read off self.parameters().
        self._prior_device = device

        # NOTE(review): the prior means are random (randn) rather than zero, so
        # every instance gets a different prior -- confirm this is intended.
        mu = torch.randn_like(self.weight, device=device)
        sigma = torch.rand_like(self.weight, device=device) + 0.01  # keep scale > 0
        # mu/sigma already have the weight's (out_features, in_features) shape,
        # so no expand is needed; to_event(2) scores the matrix as one event.
        self.weight = pyro.nn.PyroSample(
            pyro.distributions.Normal(mu, sigma).to_event(2))

        if self.bias is not None:
            mu = torch.randn_like(self.bias, device=device)
            sigma = torch.rand_like(self.bias, device=device) + 0.01
            self.bias = pyro.nn.PyroSample(
                pyro.distributions.Normal(mu, sigma).to_event(1))

    @property
    def device(self):
        """Device of the first remaining ``nn.Parameter``, else the prior device.

        Normally the layer holds no ``nn.Parameter``s (weight/bias are
        PyroSample priors), so this reports the ``device`` the priors were
        built on.
        """
        first = next(self.parameters(), None)
        return first.device if first is not None else self._prior_device
#
# Function to create a PyroLinear Layer and add ReLU activation
#
def layer_block(n_in, n_out, device='cpu'):
    """Build one hidden block: a Bayesian ``PyroLinear`` followed by ``ReLU``.

    Args:
        n_in: input width of the linear layer.
        n_out: output width of the linear layer.
        device: passed through to ``PyroLinear`` for its prior tensors.

    Returns:
        A ``PyroModule[torch.nn.Sequential]`` whose element 0 is the
        ``PyroLinear`` and element 1 is the ``ReLU``.
    """
    sequential_cls = pyro.nn.PyroModule[torch.nn.Sequential]
    bayes_linear = PyroLinear(n_in, n_out, device=device)
    activation = torch.nn.ReLU()
    return sequential_cls(bayes_linear, activation)
#
# My Neural Net Model
#
class NNet_Model(pyro.nn.PyroModule):
    """Bayesian MLP classifier built from ``PyroLinear`` layers.

    ``forward`` is the Pyro model. It reads the PyroSample priors of every
    weight/bias and conditions a ``OneHotCategorical`` likelihood on ``y``.
    ``guide`` is a hand-written mean-field Normal variational family over
    the same sample sites.

    Args:
        n_inputs: number of input features.
        n_classes: width of the output layer.
        h_layers: hidden-layer widths; one ``layer_block`` per entry.
        device: device on which the priors and the guide's initial
            variational parameters are created.
    """

    def __init__(self, n_inputs=1, n_classes=2, h_layers=(20,), device='cpu'):
        # h_layers defaults to a tuple to avoid a shared mutable default.
        super().__init__()
        self.n_classes = n_classes
        # PyroLinear replaces all weights/biases with PyroSample priors, so this
        # module owns no nn.Parameters; remember the device explicitly.
        self._prior_device = torch.device(device)
        self.layer_sizes = [n_inputs, *h_layers]
        layer_blocks = [
            layer_block(in_f, out_f, device=device)
            for in_f, out_f in zip(self.layer_sizes[:-1], self.layer_sizes[1:])
        ]
        self.feature_net = pyro.nn.PyroModule[torch.nn.Sequential](*layer_blocks)
        self.out = PyroLinear(self.layer_sizes[-1], n_classes, device=device)

    def forward(self, x, y=None):
        """Pyro model: sample weights from their priors and score ``y``.

        Args:
            x: input batch; dim 0 is used as the size of the ``"data"`` plate.
            y: optional one-hot targets. When given, the ``"obs"`` site is
                observed (conditioned on ``y``); when ``None`` it is sampled.

        Returns:
            Tensor of shape ``(batch, 1)`` holding
            ``(p0 - p1) * 0.5 + 0.5`` from the softmax probabilities.
        """
        x = self.feature_net(x)
        logits = self.out(x)
        prob = torch.softmax(logits, dim=1)
        # NOTE(review): only classes 0 and 1 enter this summary value, even
        # when n_classes > 2.
        pred = (prob[:, 0] - prob[:, 1]) * 0.5 + 0.5
        # The plate declares the batch rows conditionally independent given the
        # sampled weights. "obs" is the likelihood term; since it is observed,
        # the guide needs no matching site.
        with pyro.plate("data", x.shape[0]):
            pyro.sample("obs", pyro.distributions.OneHotCategorical(probs=prob), obs=y)
        return pred.view(x.shape[0], -1)

    def guide(self, x, y=None):
        """Mean-field Normal variational guide for every weight and bias.

        Registers learnable loc/scale params and samples the same site names
        used by the model's PyroSample priors. ``x`` and ``y`` are unused; they
        are accepted so the guide's signature matches the model's.
        """
        for i, block in enumerate(self.feature_net):
            self._guide_linear(f"feature_net.{i}.0", block[0])
        self._guide_linear("out", self.out)

    def _guide_linear(self, prefix, layer):
        """Emit variational params and weight/bias sample sites for one layer."""
        device = self._prior_device
        positive = pyro.distributions.constraints.positive
        # Shapes come from the plain int attributes. Reading layer.weight or
        # layer.bias here would fire the PyroSample prior and register a second
        # "<prefix>.weight" site, clashing with the explicit sample below.
        w_shape = (layer.out_features, layer.in_features)
        b_shape = (layer.out_features,)

        w_mu = pyro.param(f"{prefix}.w_mu", torch.randn(w_shape, device=device))
        w_sigma = pyro.param(f"{prefix}.w_sigma",
                             torch.rand(w_shape, device=device) + 0.1,
                             constraint=positive)
        pyro.sample(f"{prefix}.weight",
                    pyro.distributions.Normal(w_mu, w_sigma).to_event(2))

        b_mu = pyro.param(f"{prefix}.b_mu", torch.randn(b_shape, device=device))
        b_sigma = pyro.param(f"{prefix}.b_sigma",
                             torch.rand(b_shape, device=device) + 0.1,
                             constraint=positive)
        pyro.sample(f"{prefix}.bias",
                    pyro.distributions.Normal(b_mu, b_sigma).to_event(1))

    @property
    def device(self):
        """Device of the first ``nn.Parameter``, else the construction device."""
        first = next(self.parameters(), None)
        return first.device if first is not None else self._prior_device

    def to(self, *args, **kwargs):
        """Delegate to ``nn.Module.to`` and return ``self`` so chaining works.

        NOTE(review): the PyroSample prior tensors are neither parameters nor
        buffers, so ``Module.to`` presumably does not move them -- confirm;
        building the model with ``device=`` avoids the issue.
        """
        return super().to(*args, **kwargs)