Trouble Specifying Hyperpriors within a PyroModule (site not in trace)

Hi Everyone,

I was wondering if you could help me with the issue below. I am preparing a tutorial on Pyro and wanted to demonstrate a logistic regression example with hyperpriors as a way of introducing how PyroModule works.

However, I can’t get it to behave properly: the variable ‘lam’ never shows up in the trace, and an autoguide also ignores that site. I suspect I am missing something.

I have this model working without PyroModule, but I would like to show it both ways for comparison. What am I missing? I have also tried specifying ‘lam’ with PyroSample.

import pyro
from pyro.distributions import Normal, Bernoulli, HalfCauchy
from pyro.nn import PyroModule, PyroSample
from torch.nn.modules import Linear

class Model(pyro.nn.PyroModule):
    def __init__(self, in_dim, w_prior=1.):
        super().__init__()
                        
        self.l1 = PyroModule[Linear](in_dim, 1) 
        self.lam = pyro.sample('lam', HalfCauchy(scale=1.))
        self.l1.weight = PyroSample(Normal(0., self.lam).expand([1, in_dim]).to_event(2))
        self.l1.bias = PyroSample(Normal(0., 10.))
         

    def forward(self, x, y_obs = None):  
        alpha = self.l1(x)   
        with pyro.plate("data", x.shape[0]):
            obs = pyro.sample("obs", Bernoulli(logits = alpha.reshape(-1)), obs = y_obs)
            
        return obs

You need to use PyroSample, not pyro.sample, for lam.

Thanks, @martinjankowiak; however, it doesn’t work that way either. My alternate code is at the bottom; it shows a second way I tried using PyroSample.

Running

trace = pyro.poutine.trace(LR).get_trace(X)

print(trace.format_shapes())

still gives the following output (X has shape (200, 10)); note that there is no lam site:

  Trace Shapes:           
  Param Sites:           
 Sample Sites:           
l1.weight dist     | 1 10
         value     | 1 10
  l1.bias dist   1 |     
         value   1 |     
     data dist     |     
         value 200 |     
      obs dist 200 |     
         value 200 |     

Here is my updated model code:

import pyro
from pyro.distributions import Normal, Bernoulli, HalfCauchy
from pyro.nn import PyroModule, PyroSample
from torch.nn.modules import Linear

class BayesianLogisticRegression(pyro.nn.PyroModule):
    def __init__(self, in_dim):
        super().__init__()
                 
        self.l1 = PyroModule[Linear](in_dim, 1) 
        
        # lambda ~ HalfCauchy(1)
        self.lam = PyroSample(HalfCauchy(scale=1.).expand([1])) 
        
        # W | lambda ~ N(0, lam**2)
        self.l1.weight = PyroSample(Normal(0., self.lam).expand([1, in_dim]).to_event(2))
        
        # b ~ N(0,100)
        self.l1.bias = PyroSample(Normal(0., 10.).expand([1]))
         
    def forward(self, x, y_obs = None):  
        alpha = self.l1(x)  

        with pyro.plate("data", x.shape[0]):
            obs = pyro.sample("obs", Bernoulli(logits = alpha.reshape(-1)), obs = y_obs)
            
        return obs

Hi @Robsal, your second example is almost correct. The issue is that when you write

self.l1.weight = PyroSample(Normal(0., self.lam).expand([1, in_dim]).to_event(2))
#                                  --> XXXXXXXX <-- error here

the self.lam binds early. This is a little subtle and involves the way Pyro adds effects to PyroSample and PyroParam attributes. Pyro does this by overriding the module’s .__getattr__() method (just as torch.nn.Module does), so that effects are applied at the moment an attribute is accessed. For example, if you’d like to trace your sample statement

self.lam = PyroSample(HalfCauchy(scale=1.).expand([1]))

then that tracing happens exactly when the module looks up the attribute self.lam. The problem with your second model is that this lookup happens eagerly in your .__init__() method rather than in your .forward() method, so the weight prior captures a single fixed value of lam and no lam site is ever recorded. But that’s easy to fix: PyroSample accepts lazy priors, so we can make the following small change:

- self.l1.weight = PyroSample(Normal(0., self.lam).expand([1, in_dim]).to_event(2))
+ self.l1.weight = PyroSample(lambda _: Normal(0., self.lam).expand([1, in_dim]).to_event(2))

Now the self.lam lookup doesn’t happen in .__init__(); rather, it happens much later in .forward() when you call self.l1(x). That is, calling self.l1(x) triggers the special PyroModule.__getattr__() logic, which triggers your lazy prior evaluation, which in turn triggers another special PyroModule.__getattr__() call to get self.lam.

Thanks for working with this kind of weird framework. It was definitely challenging for us to bolt probabilistic programming effects onto PyTorch’s existing nn.Module architecture. I hope you can get your model working.

Here is a complete working example:

import torch
import pyro
from pyro.distributions import Normal, Bernoulli, HalfCauchy
from pyro.nn import PyroModule, PyroSample
from torch.nn.modules import Linear

class BayesianLogisticRegression(pyro.nn.PyroModule):
    def __init__(self, in_dim):
        super().__init__()
        self.l1 = PyroModule[Linear](in_dim, 1)
        self.lam = PyroSample(HalfCauchy(scale=1.).expand([1]))
        self.l1.weight = PyroSample(
            lambda _: Normal(0., self.lam).expand([1, in_dim]).to_event(2)
        )
        self.l1.bias = PyroSample(Normal(0., 10.).expand([1]))
         
    def forward(self, x, y_obs = None):  
        alpha = self.l1(x)  
        with pyro.plate("data", x.shape[0]):
            obs = pyro.sample("obs", Bernoulli(logits = alpha.reshape(-1)), obs = y_obs)
        return obs

Then we can run:

X = torch.randn(200, 10)
LR = BayesianLogisticRegression(10)
trace = pyro.poutine.trace(LR).get_trace(X)
print(trace.format_shapes())
 Trace Shapes:           
  Param Sites:           
 Sample Sites:           
      lam dist   1 |     
         value   1 |     
l1.weight dist     | 1 10
         value     | 1 10
  l1.bias dist   1 |     
         value   1 |     
     data dist     |     
         value 200 |     
      obs dist 200 |     
         value 200 |