This is an issue with the torch.diag() interface not supporting broadcasting over batch dimensions. I recommend avoiding torch.diag().
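A minimal sketch of how this shows up in practice. It assumes a batch of vectors stored as a (4, 3) tensor; the shapes and values are illustrative only. A 2-D batch is silently treated as one matrix, and a 3-D input is rejected:

```python
import torch

# Hypothetical batch: 4 vectors of length 3, stored as a (4, 3) tensor.
batch = torch.arange(12.0).reshape(4, 3)

# torch.diag sees a 2-D tensor and treats it as ONE matrix, so it returns
# that matrix's main diagonal instead of building 4 diagonal matrices.
wrong = torch.diag(batch)
print(wrong.shape)  # torch.Size([3])

# A 3-D input (e.g. a batch of matrices) is not accepted at all.
rejected = False
try:
    torch.diag(torch.zeros(2, 3, 3))
except (RuntimeError, ValueError):
    rejected = True
```

No error is raised in the 2-D case, which is what makes the silent misinterpretation easy to miss.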
- When x is a vector, torch.diag(x) is a matrix whose diagonal is x.
- When x is a matrix, torch.diag(x) is a vector which is the diagonal of x.
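The two cases above can be checked directly. The example values are arbitrary:

```python
import torch

# Vector in -> matrix out: x is placed on the main diagonal.
x = torch.tensor([1.0, 2.0, 3.0])
m = torch.diag(x)  # 3x3 matrix, zeros off the diagonal

# Matrix in -> vector out: the main diagonal of A is extracted.
A = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
d = torch.diag(A)  # tensor([1., 4.])
```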
Thus, for a vector x, torch.diag(torch.diag(x)) == x.
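A sketch of the round trip, plus one possible batch-friendly replacement. The use of torch.diag_embed and torch.diagonal here is a suggestion, not something stated above: diag_embed always builds matrices from the last dimension, and diagonal(..., dim1=-2, dim2=-1) always reads from the last two dimensions, so neither changes meaning with the input's rank:

```python
import torch

# Round trip for a single vector: vector -> diagonal matrix -> vector.
x = torch.tensor([1.0, 2.0, 3.0])
assert torch.equal(torch.diag(torch.diag(x)), x)

# Batched version of the same round trip (shapes are illustrative).
xb = torch.randn(4, 3)
mats = torch.diag_embed(xb)                    # (4, 3, 3): one diagonal matrix per row
back = torch.diagonal(mats, dim1=-2, dim2=-1)  # (4, 3): diagonals read back out
```

Note that torch.diagonal defaults to dim1=0, dim2=1, so the explicit dim1=-2, dim2=-1 is what makes it treat leading dimensions as the batch.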