pyro.plate error with torch.diag()

This happens because the torch.diag() interface does not support broadcasting over batch dimensions. I recommend avoiding torch.diag().

  • When x is a vector, torch.diag(x) is a matrix whose diagonal is x.
  • When x is a matrix, torch.diag(x) is a vector which is the diagonal of x.

Thus for vectors torch.diag(torch.diag(x)) == x.
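
To make the two behaviours concrete, here is a minimal sketch. The use of torch.diag_embed() and torch.diagonal() as batch-friendly replacements is my suggestion and is not part of the answer above; the tensor values and shapes are made up for illustration.

```python
import torch

# Vector in, matrix out; matrix in, vector out.
x = torch.tensor([1.0, 2.0, 3.0])
m = torch.diag(x)                  # shape (3, 3), diagonal is x
v = torch.diag(m)                  # shape (3,), recovers x
assert torch.equal(v, x)

# A batch of 2 vectors of length 3 (illustrative values).
batch = torch.arange(6.0).reshape(2, 3)

# torch.diag() sees a 2-D input as a single matrix, not a batch of
# vectors, so it returns the diagonal of that 2x3 matrix.
wrong = torch.diag(batch)          # shape (2,), not (2, 3, 3)

# Assumed alternative: these act on the trailing dimension(s) and
# leave any leading batch dimensions alone.
mats = torch.diag_embed(batch)                  # shape (2, 3, 3)
back = torch.diagonal(mats, dim1=-2, dim2=-1)   # shape (2, 3)
assert torch.equal(back, batch)
```

Because a plate can add a leading batch dimension to a tensor, a vector you pass to torch.diag() may arrive as a 2-D tensor and silently take the "matrix" branch. Operations that work on the last dimensions do not have this ambiguity.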