Understanding Neural Net training with Pyro

I’m trying to understand the process of training a Bayesian neural network. While the examples have been helpful, they don’t seem to provide the full picture.

My understanding is that I want a network of linear layers in which each weight and bias is drawn from a Normal distribution whose mean and scale are trained for each connection.

I believe I have created such a network; however, I’m confused about where Pyro takes over for the training process. For the obs sample site, don’t I want to pull samples from running multiple times through my network? The doc tutorials seem to indicate creating something like pyro.sample("obs", pyro.distributions.OneHotCategorical(probs=prob), obs=y), or a pyro.distributions.Normal() likelihood, etc.
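
For example, the tutorials condition on the data with roughly this pattern (mean and sigma being quantities computed earlier in the tutorial’s model):

# Pattern from the doc tutorials (as I understand it): obs=y conditions
# the sample site on the observed data, so Pyro scores y under the
# likelihood rather than drawing new values during training.
with pyro.plate("data", x.shape[0]):
    pyro.sample("obs", pyro.distributions.Normal(mean, sigma), obs=y)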

My data has several classes; however, they are normally distributed between 0 and 1. So I am trying to train a network which will predict probabilities for the extremes and boil the one-hot encoding of the two classes down into a single value.
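
(Working that mapping through for two classes: softmax gives p0 + p1 = 1, so (p0 - p1) * 0.5 + 0.5 = (2 * p0 - 1) * 0.5 + 0.5 = p0, i.e. the single value is exactly the predicted probability of the first class.)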

Here is what I have… I feel like the pyro.plate and obs sampling may need to move, but that’s not what the documentation seems to indicate.

import torch
import pyro

##
## My custom Linear layer
##
class PyroLinear(torch.nn.Linear, pyro.nn.PyroModule):  # used as a mixin
    def __init__(self, in_features, out_features, device='cpu', **kwargs):
        super().__init__(in_features, out_features, **kwargs)
        # PyroSample replaces the fixed nn.Linear parameters with priors;
        # a fresh weight is drawn from the prior on each forward call.
        mu = torch.randn_like( self.weight, device=device)
        sigma = torch.rand_like( self.weight, device=device) + 0.01

        self.weight = pyro.nn.PyroSample( pyro.distributions.Normal( mu, sigma).expand([self.out_features, self.in_features]).to_event(2))
        if self.bias is not None:
            mu = torch.randn_like( self.bias, device=device)
            sigma = torch.rand_like( self.bias, device=device) + 0.01
            self.bias = pyro.nn.PyroSample( pyro.distributions.Normal( mu, sigma).expand([self.out_features]).to_event(1))
            
    @property
    def device(self):
        if len( list(self.parameters())) > 0:
            return next(self.parameters()).device
        else:
            return 'cpu'

#
# Function to create a PyroLinear Layer and add ReLU activation
# 
def layer_block( n_in, n_out, device='cpu'):
    return pyro.nn.PyroModule[ torch.nn.Sequential](
        PyroLinear(n_in, n_out, device=device),
        torch.nn.ReLU()
    )

# 
# My Neural Net Model
# 
class NNet_Model(pyro.nn.PyroModule):
    def __init__(self, n_inputs=1, n_classes=2, h_layers=[20], device='cpu'):
        super().__init__()
        self.n_classes = n_classes
        self.layer_sizes = [n_inputs, *h_layers]

        layer_blocks = [ layer_block(in_f, out_f, device=device) for in_f, out_f in zip(self.layer_sizes[:-1], self.layer_sizes[1:])]
        
        self.feature_net = pyro.nn.PyroModule[ torch.nn.Sequential](*layer_blocks)
        self.out = PyroLinear( self.layer_sizes[ -1], n_classes, device=device)

    def forward(self, x, y=None):
        x = self.feature_net( x)
        mu = self.out( x)

        prob = torch.softmax( mu, dim=1)
        pred = (prob[:,0] - prob[:,1]) * 0.5 + 0.5  # == prob[:,0] for two classes
        #
        # Need this for when the model flows through the guide?
        
        with pyro.plate("data", x.shape[0]):
            obs = pyro.sample("obs", pyro.distributions.OneHotCategorical(probs=prob), obs=y)

        
        return pred.view( x.shape[0], -1)

        
    def guide(self, x, y=None):
        # Site names here must match the PyroSample site names in the
        # model exactly (e.g. "feature_net.0.0.weight", "out.bias").
        for i in range(0, len( self.feature_net)):
            mu = torch.randn_like( self.feature_net[i][0].weight)
            sigma = torch.rand_like( self.feature_net[i][0].weight) + 0.1
            mu_param = pyro.param( f"feature_net.{i}.0.w_mu", mu)
            sigma_param = pyro.param( f"feature_net.{i}.0.w_sigma", sigma, constraint=pyro.distributions.constraints.positive)
            _ = pyro.sample( f"feature_net.{i}.0.weight",  pyro.distributions.Normal(mu_param, sigma_param).expand([self.feature_net[i][0].weight.size(0), self.feature_net[i][0].weight.size(1)]).to_event(2))
            mu = torch.randn_like( self.feature_net[i][0].bias)
            sigma = torch.rand_like( self.feature_net[i][0].bias) + 0.1
            mu_param = pyro.param( f"feature_net.{i}.0.b_mu", mu)
            sigma_param = pyro.param( f"feature_net.{i}.0.b_sigma", sigma, constraint=pyro.distributions.constraints.positive)
            _ = pyro.sample( f"feature_net.{i}.0.bias",  pyro.distributions.Normal(mu_param, sigma_param).expand([self.feature_net[i][0].bias.size(0)]).to_event(1))
        mu = torch.randn_like( self.out.weight)
        sigma = torch.rand_like( self.out.weight) + 0.1
        mu_param = pyro.param( f"out.w_mu", mu)
        sigma_param = pyro.param( f"out.w_sigma", sigma, constraint=pyro.distributions.constraints.positive)
        _ = pyro.sample( f"out.weight",  pyro.distributions.Normal(mu_param, sigma_param).expand([self.out.weight.size(0), self.out.weight.size(1)]).to_event(2))
        mu = torch.randn_like( self.out.bias)
        sigma = torch.rand_like( self.out.bias) + 0.1
        mu_param = pyro.param( f"out.b_mu", mu)
        sigma_param = pyro.param( f"out.b_sigma", sigma, constraint=pyro.distributions.constraints.positive)
        _ = pyro.sample( f"out.bias",  pyro.distributions.Normal(mu_param, sigma_param).expand([self.out.bias.size(0)]).to_event(1))
    
    @property
    def device(self):
        return next(self.parameters()).device
    
    def to( self, device):
        # tensor.to() is not in-place, so calling it on the sampled
        # weights/biases and discarding the result would do nothing;
        # nn.Module.to() already moves the registered tensors.
        super().to( device)
        return self
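
For completeness, here is the training loop I was planning to drive this with. It is a sketch following the SVI pattern from the tutorials; X_train and y_train stand in for my data (with y one-hot encoded):

#
# Training loop (sketch); X_train / y_train are placeholders for my data
#
from pyro.infer import SVI, Trace_ELBO

pyro.clear_param_store()
model = NNet_Model(n_inputs=X_train.shape[1], n_classes=2)

# SVI ties the model and guide together: each step() samples weights
# from the guide, replays the model against them, and takes a gradient
# step on the ELBO with respect to the pyro.param values in the guide.
svi = SVI(model, model.guide, pyro.optim.Adam({"lr": 1e-3}), loss=Trace_ELBO())

for step in range(5000):
    loss = svi.step(X_train, y_train)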

We generally recommend using TyXE to implement Bayesian neural networks on top of Pyro instead of doing so in raw Pyro.

Thank you @martinjankowiak, I had not discovered TyXE. I’ve been taking some time to understand it, without much success yet. The documentation appears to be a bit out of date relative to the code base (the docs reference methods/classes that no longer exist, or that I can’t find). Would you be familiar with any other examples of TyXE out there, or is there more detailed documentation for this package?
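
For reference, the furthest I got was something along these lines, pieced together from the README. I am not at all sure these class names still match the current code base:

#
# TyXE sketch pieced together from the README; names may be stale
#
import pyro
import tyxe

# net is an ordinary torch.nn.Sequential classifier; X_train and
# data_loader stand in for my data.
prior = tyxe.priors.IIDPrior(pyro.distributions.Normal(0., 1.))
likelihood = tyxe.likelihoods.Categorical(dataset_size=len(X_train))
bnn = tyxe.VariationalBNN(net, prior, likelihood, tyxe.guides.AutoNormal)

optim = pyro.optim.Adam({"lr": 1e-3})
bnn.fit(data_loader, optim, num_epochs=100)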


Sorry, I don’t have any further info.