Tensor shape error using SVGD

I'm getting a tensor shape error when I run SVGD on a fairly simple model that contains a Linear module. My model is:

    import torch as t
    import pyro
    import pyro.nn
    import pyro.distributions as dist
    from contextlib import nullcontext

    class ErrorModel(pyro.nn.PyroModule):
        def __init__(self, in_feats):
            super().__init__()  # needed before submodules can be assigned
            self.linear = pyro.nn.PyroModule[t.nn.Linear](in_feats, 1)
            self.linear.weight = pyro.nn.PyroSample(dist.Normal(0., 1.).expand([1, in_feats]).to_event(2))
            self.linear.bias = pyro.nn.PyroSample(dist.Normal(0., 5.).expand([1]).to_event(1))

        def forward(self, feats, errs=None):
            u = self.linear(feats).squeeze(-1)
            s = pyro.sample("s", dist.HalfNormal(t.tensor(5.)))
            # pyro.plate takes a name first, then the size
            with pyro.plate("data", len(errs)) if errs is not None else nullcontext():
                abs_err = pyro.sample("abs_err", dist.LogNormal(u, s),
                                      obs=errs.abs() if errs is not None else None)
            return abs_err * (t.rand(abs_err.shape) - 0.5).sign()

    em = ErrorModel(imed.shape[1])

Now I try to use SVGD:

    opt_svgd = pyro.optim.Adam({'lr': 0.001})
    svgd = pyro.infer.SVGD(em, pyro.infer.RBFSteinKernel(), opt_svgd,
                           num_particles=2, max_plate_nesting=1)

And I get the following error:

    ~/.local/lib/python3.7/site-packages/torch/nn/functional.py in linear(input, weight, bias)
       1608     if input.dim() == 2 and bias is not None:
       1609         # fused op is marginally faster
    -> 1610         ret = torch.addmm(bias, input, weight.t())
       1611     else:
       1612         output = input.matmul(weight.t())

    RuntimeError: t() expects a tensor with <= 2 dimensions, but self is 4D
         Trace Shapes:
          Param Sites:
         Sample Sites:
            u_max dist 2 1 |
                 value 2 1 |
    linear.weight dist 2 1 | 1 16
                 value 2 1 | 1 16
      linear.bias dist 2 1 | 1
                 value 2 1 | 1
    Trace Shapes:
     Param Sites:
    Sample Sites:

So what seems to be happening is this: linear.weight has shape [1, 16] in my model, but somewhere along the line SVGD expands it to [2, 1, 1, 16], presumably to accommodate my 2 particles and max_plate_nesting of 1. The Linear module then tries to transpose the weight matrix with weight.t(), which fails because of the extra dimensions.

So my question is: am I doing something wrong here, or is this a limitation of the SVGD implementation?

This is a limitation of nn.Linear, which only knows how to handle 2D weight matrices. You need a module that can handle extra batch dimensions to the left of the weight matrix.