Making linear model component broadcastable

I am trying to extend the Bayesian regression example for my own purposes, but I run into problems when using the resulting model for prediction. Namely, the linear component does not seem to broadcast, and I get the following error when running Predictive:

RuntimeError                              Traceback (most recent call last)
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/linear.py in forward(self, input)
    101 
    102     def forward(self, input: Tensor) -> Tensor:
--> 103         return F.linear(input, self.weight, self.bias)
    104 
    105     def extra_repr(self) -> str:

RuntimeError: t() expects a tensor with <= 2 dimensions, but self is 3D
              Trace Shapes:          
               Param Sites:          
              Sample Sites:          
age_kernel.lengthscale dist    |     
                      value  1 |     
   age_kernel.variance dist    |     
                      value  1 |     
                  ages dist    |     
                      value 37 |     
           f_tilde_age dist 37 |     
                      value 37 |     
                 f_age dist    | 37  
                      value    | 37  
      age_trend.weight dist    |  1 1
                      value  1 |  1 1
        age_trend.bias dist    |  1  
                      value  1 |  1 
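The failure can be reproduced in plain PyTorch, without Pyro. This is a minimal sketch, assuming (as the trace shapes suggest, with the weight value shown as 1 | 1 1) that the sampled age_trend.weight carries one extra leading batch dimension, so it is (1, 1, 1) instead of (1, 1). The tensor values are random and only the shapes matter.

```python
import torch
import torch.nn.functional as F

# Input shaped like age_idx.reshape(-1, 1).float(): (N, in_features)
x = torch.arange(5.0).reshape(-1, 1)

# Unbatched parameters, which is what nn.Linear expects: this works
w2d = torch.randn(1, 1)   # (out_features, in_features)
b1d = torch.randn(1)      # (out_features,)
print(F.linear(x, w2d, b1d).shape)  # torch.Size([5, 1])

# The same parameters with one extra leading batch dimension, as in the trace
w3d = w2d.unsqueeze(0)    # (1, out_features, in_features)
b2d = b1d.unsqueeze(0)    # (1, out_features)

try:
    F.linear(x, w3d, b2d)
    linear_failed = False
except RuntimeError as err:
    # F.linear transposes the weight with .t(), which only accepts <= 2-D tensors
    linear_failed = True
    print("RuntimeError:", err)
```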

The salient parts of my model are these:

class MyModel(PyroModule):

    def __init__(self, *args, **kwargs):
        ...

        self.age_trend = PyroModule[torch.nn.Linear](1, 1)
        self.age_trend.weight = PyroSample(dist.Normal(dtensor(0.), dtensor(1.)).expand([1, 1]).to_event(2))
        self.age_trend.bias = PyroSample(dist.Normal(dtensor(0.), dtensor(10.)).expand([1]).to_event(1))

    def forward(self, age_idx, y=None):
        ...
        cov_age = self.age_kernel(torch.arange(A, device=device)).contiguous()

        with pyro.plate("ages", A):
            f_tilde_age = pyro.sample("f_tilde_age", dist.Normal(dtensor(0.0), dtensor(1.0)))

        f_age = pyro.deterministic(
            "f_age", torch.linalg.cholesky(cov_age + torch.eye(A, device=device) * jitter) @ f_tilde_age.squeeze()
        )

        theta = self.age_trend(age_idx.reshape(-1, 1).float()).squeeze() + f_age[..., age_idx]

        ...

dtensor is just my convenience function for sending tensors to the appropriate device.

I’m sure I am missing something silly, but can’t think of why it would not broadcast properly as specified.

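nn.Linear is written by neural-network people, whose default assumption is that a network has a single set of weights. Consequently, nn.Linear broadcasts over the input (batch) dimensions but not over the weight dimensions. In your trace the sampled age_trend.weight has an extra leading batch dimension (value shape 1 | 1 1), so it is 3-D, and F.linear fails when it calls t() on it.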
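One way around this is to skip nn.Linear for this term and write the affine map with operations that do broadcast over leading dimensions. Below is a minimal sketch in plain PyTorch; batched_linear is a hypothetical helper name, and the shapes assume the weight is (..., out_features, in_features) and the bias is (..., out_features), matching the PyroSample priors in the question plus any leading batch dimensions.

```python
import torch


def batched_linear(x, weight, bias):
    """Hypothetical helper: y = x @ W^T + b, broadcasting over any leading
    batch dimensions of weight and bias (unlike nn.Linear / F.linear).

    x:      (N, in_features)
    weight: (..., out_features, in_features)
    bias:   (..., out_features)
    returns (..., N, out_features)
    """
    # transpose(-1, -2) swaps only the last two dims, so leading dims survive,
    # and matmul broadcasts x against every weight in the batch
    out = x @ weight.transpose(-1, -2)
    # add a length-1 axis so each bias broadcasts across the N rows
    return out + bias.unsqueeze(-2)


x = torch.arange(5.0).reshape(-1, 1)   # like age_idx.reshape(-1, 1).float()
weight = torch.randn(3, 1, 1)          # 3 weight draws, each (1, 1)
bias = torch.randn(3, 1)               # 3 bias draws, each (1,)
y = batched_linear(x, weight, bias)
print(y.shape)  # torch.Size([3, 5, 1])
```

In the model, the idea would be to read the sampled weight and bias once into local variables and compute the trend as batched_linear(age_idx.reshape(-1, 1).float(), w, b).squeeze(-1). Squeezing only the last axis, rather than calling a bare squeeze(), keeps any size-1 batch dimension that Predictive adds.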
