I’m encountering a tensor-size error when using SVGD with a fairly simple model that contains a Linear module. My model is:

```
from contextlib import nullcontext

import torch as t
import pyro
import pyro.distributions as dist

class ErrorModel(pyro.nn.PyroModule):
    def __init__(self, in_feats):
        super().__init__()
        self.linear = pyro.nn.PyroModule[t.nn.Linear](in_feats, 1)
        self.linear.weight = pyro.nn.PyroSample(
            dist.Normal(0., 1.).expand([1, in_feats]).to_event(2))
        self.linear.bias = pyro.nn.PyroSample(
            dist.Normal(0., 5.).expand([1]).to_event(1))

    def forward(self, feats, errs=None):
        u = self.linear(feats).squeeze(-1)
        s = pyro.sample("s", dist.HalfNormal(t.tensor(5.)))
        with pyro.plate("data", len(errs)) if errs is not None else nullcontext():
            abs_err = pyro.sample("abs_err", dist.LogNormal(u, s),
                                  obs=errs.abs() if errs is not None else None)
        return abs_err * (t.rand(abs_err.shape) - 0.5).sign()

em = ErrorModel(imed.shape[1])
```

Now I try to use SVGD:

```
opt_svgd = pyro.optim.Adam({'lr': 0.001})
svgd = pyro.infer.SVGD(em, pyro.infer.RBFSteinKernel(), opt_svgd,
                       num_particles=2, max_plate_nesting=1)
svgd.step()
```
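As far as I can tell (this is my reading of the trace below, not documented behaviour), SVGD prepends `[num_particles, max_plate_nesting]` dimensions to every latent site. A NumPy sketch of that shape bookkeeping:

```
import numpy as np

# My reading of SVGD's behaviour: each latent gets leading
# [num_particles, max_plate_nesting] dims, so the [1, 16] weight
# becomes [2, 1, 1, 16] with 2 particles and plate nesting 1.
num_particles, max_plate_nesting = 2, 1
weight = np.zeros((1, 16))
batched = np.broadcast_to(weight,
                          (num_particles, max_plate_nesting) + weight.shape)
print(batched.shape)  # (2, 1, 1, 16)
```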

And I get the following error:

```
~/.local/lib/python3.7/site-packages/torch/nn/functional.py in linear(input, weight, bias)
   1608     if input.dim() == 2 and bias is not None:
   1609         # fused op is marginally faster
-> 1610         ret = torch.addmm(bias, input, weight.t())
   1611     else:
   1612         output = input.matmul(weight.t())

RuntimeError: t() expects a tensor with <= 2 dimensions, but self is 4D
        Trace Shapes:
         Param Sites:
        Sample Sites:
   u_max dist 2 1 |
        value 2 1 |
linear.weight dist 2 1 | 1 16
            value 2 1 | 1 16
  linear.bias dist 2 1 | 1
            value 2 1 | 1
Trace Shapes:
 Param Sites:
Sample Sites:
```

So what seems to be happening is this: `linear.weight` has shape [1, 16] in my model, but somewhere along the line SVGD expands it to [2, 1, 1, 16] to accommodate my 2 particles and max plate nesting of 1 (I guess). The Linear module then internally tries to transpose the weight matrix with `weight.t()`, which fails because of the extra leading dimensions.
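In case it clarifies things, here is the shape arithmetic for the workaround I'm considering: computing the affine map by hand with a batch-safe transpose that swaps only the last two axes (what `weight.transpose(-2, -1)` would do in torch), instead of the 2-D-only `weight.t()` that `nn.Linear` calls. Sketched in NumPy with made-up sizes:

```
import numpy as np

# Hypothetical workaround: a hand-rolled affine map that broadcasts over
# the leading particle/plate dims instead of calling weight.t().
num_particles, n_obs, in_feats = 2, 5, 16
feats = np.random.randn(n_obs, in_feats)
weight = np.random.randn(num_particles, 1, 1, in_feats)  # SVGD-expanded weight
bias = np.random.randn(num_particles, 1, 1)

# matmul broadcasts over the leading dims; swapaxes only touches the last two:
u = feats @ np.swapaxes(weight, -1, -2) + bias[..., None]
print(u.shape)  # (2, 1, 5, 1); squeezing the last dim gives (2, 1, 5)
```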

So my question is: am I doing something wrong here, or is this a limitation of the SVGD implementation?