I am trying to extend the Bayesian regression example for my own purposes, but I am running into issues when using the resulting model to predict. Specifically, the linear component does not seem to broadcast, as I get the following error when running Predictive:
RuntimeError                              Traceback (most recent call last)
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/linear.py in forward(self, input)
    101
    102     def forward(self, input: Tensor) -> Tensor:
--> 103         return F.linear(input, self.weight, self.bias)
    104
    105     def extra_repr(self) -> str:

RuntimeError: t() expects a tensor with <= 2 dimensions, but self is 3D
              Trace Shapes:
               Param Sites:
              Sample Sites:
age_kernel.lengthscale dist    |
                      value  1 |
   age_kernel.variance dist    |
                      value  1 |
                  ages dist    |
                      value 37 |
           f_tilde_age dist 37 |
                      value 37 |
                 f_age dist    | 37
                      value    | 37
      age_trend.weight dist    | 1 1
                      value  1 | 1 1
        age_trend.bias dist    | 1
                      value  1 | 1
The salient parts of my model are these:
class MyModel(PyroModule):
    def __init__(self, *args, **kwargs):
        ...
        self.age_trend = PyroModule[torch.nn.Linear](1, 1)
        self.age_trend.weight = PyroSample(dist.Normal(dtensor(0.), dtensor(1.)).expand([1, 1]).to_event(2))
        self.age_trend.bias = PyroSample(dist.Normal(dtensor(0.), dtensor(10.)).expand([1]).to_event(1))

    def forward(self, age_idx, y=None):
        ...
        cov_age = self.age_kernel(torch.arange(A, device=device)).contiguous()
        with pyro.plate("ages", A):
            f_tilde_age = pyro.sample("f_tilde_age", dist.Normal(dtensor(0.0), dtensor(1.0)))
        f_age = pyro.deterministic(
            "f_age", torch.linalg.cholesky(cov_age + torch.eye(A, device=device) * jitter) @ f_tilde_age.squeeze()
        )
        theta = self.age_trend(age_idx.reshape(-1, 1).float()).squeeze() + f_age[..., age_idx]
        ...
dtensor is just my convenience function for sending tensors to the appropriate device.
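For anyone trying to run the snippet, a stand-in for dtensor might look roughly like the sketch below. This is a hypothetical reconstruction, not the original definition: the way `device` is chosen and the pass-through of arguments to `torch.tensor` are both assumptions.

```python
import torch

# Assumption: `device` is a module-level global; the post does not show how it is set.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def dtensor(*args, **kwargs):
    """Hypothetical stand-in: build a tensor directly on `device`."""
    return torch.tensor(*args, device=device, **kwargs)


zero = dtensor(0.0)  # scalar tensor placed on `device`
```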
I’m sure I am missing something silly, but can’t think of why it would not broadcast properly as specified.
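The shape issue can be reproduced outside of Pyro. In the trace above, the value of age_trend.weight has shape (1, 1, 1): a leading sample dimension in front of the (1, 1) event shape. The sketch below passes a weight of that shape to `F.linear`, then writes the affine map by hand so the leading dimension can broadcast. The input values and the variable names are made up for illustration, and the manual version is only one possible workaround, not necessarily the recommended Pyro approach.

```python
import torch
import torch.nn.functional as F

x = torch.arange(5.0).reshape(-1, 1)   # inputs, shape (5, 1)
weight = torch.full((1, 1, 1), 2.0)    # weight with a leading sample dim, shape (1, 1, 1)
bias = torch.full((1, 1), 0.5)         # bias with a leading sample dim, shape (1, 1)

# F.linear expects a 2D weight, so the extra leading dimension is rejected.
try:
    F.linear(x, weight, bias)
    linear_failed = False
except RuntimeError as err:
    linear_failed = True
    print("F.linear failed:", err)

# Writing the affine map by hand lets the sample dimension broadcast:
# (5, 1) @ (1, 1, 1) -> (1, 5, 1); the bias is reshaped to (1, 1, 1) to match.
theta = (x @ weight.transpose(-1, -2) + bias.unsqueeze(-2)).squeeze(-1)
print(theta.shape)  # one row of outputs per sample
```

Here `squeeze(-1)` is used instead of a bare `squeeze()` so that a sample dimension of size 1 is not dropped along with the feature dimension.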