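This is an issue with the torch.diag() interface, which does not support broadcasting: it accepts only 1-D or 2-D inputs, and what it returns depends on which of the two it receives. I recommend avoiding torch.diag().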
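A minimal sketch of the broadcasting problem, assuming a batch of matrices stacked along a leading dimension. torch.diag() rejects the 3-D input rather than treating the leading dimension as a batch. The batched alternative shown here, torch.diagonal() with dim1/dim2, is a suggestion and not part of the original answer.

```python
import torch

batch = torch.arange(18.0).reshape(2, 3, 3)  # two stacked 3x3 matrices

# torch.diag only accepts 1-D or 2-D tensors, so a 3-D batch raises an error.
try:
    torch.diag(batch)
    diag_failed = False
except RuntimeError:
    diag_failed = True

# torch.diagonal treats the leading dims as batch dims and takes the
# diagonal over the last two dims, giving shape (2, 3).
batched_diag = torch.diagonal(batch, dim1=-2, dim2=-1)
print(diag_failed, batched_diag.shape)
```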
- When x is a vector, torch.diag(x) is a matrix whose diagonal is x.
- When x is a matrix, torch.diag(x) is a vector containing the diagonal of x.
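The two behaviors in the list can be checked directly. This is a small illustration with made-up example tensors v and a.

```python
import torch

v = torch.tensor([1.0, 2.0, 3.0])
m = torch.diag(v)  # vector in: 3x3 matrix with v on its diagonal, zeros elsewhere
print(m)

a = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
d = torch.diag(a)  # matrix in: 1-D tensor holding the diagonal of a
print(d)
```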
Thus, for vectors, torch.diag(torch.diag(x)) == x.
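A sketch of the round trip, plus one possible batched replacement. The round trip is the identity only for vectors; applied to a matrix it keeps the diagonal and zeroes the off-diagonal entries. The replacement pairs torch.diag_embed() (vectors to diagonal matrices) with torch.diagonal() (matrices to diagonals); each has a single meaning regardless of input rank and works over leading batch dimensions. This pairing is a suggestion, not something the original answer prescribes.

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0])
vector_roundtrip_ok = torch.equal(torch.diag(torch.diag(x)), x)  # holds for vectors

a = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
# For a matrix, the round trip drops the off-diagonal entries.
matrix_roundtrip = torch.diag(torch.diag(a))
print(vector_roundtrip_ok, matrix_roundtrip)

# Batched alternative: a batch of two length-2 vectors.
vs = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
mats = torch.diag_embed(vs)  # shape (2, 2, 2): one diagonal matrix per vector
back = torch.diagonal(mats, dim1=-2, dim2=-1)  # recovers vs, shape (2, 2)
```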