This may be due to your use of torch.diag(), which does not support broadcasting: it accepts only a 1-D or 2-D tensor and does not treat leading dimensions as batch dimensions. See this post for details.
torch.diag()
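
If that is the cause, a batched alternative may help. The snippet below is only a sketch: the tensor name `v` and its shape `(4, 3)` are assumptions made for illustration, not taken from your code. It contrasts torch.diag() with torch.diag_embed() and torch.diagonal(), which do operate over leading batch dimensions.

```python
import torch

# Hypothetical input: a batch of 4 vectors of length 3.
v = torch.arange(12, dtype=torch.float32).reshape(4, 3)

# torch.diag() sees a 2-D input as ONE matrix and returns its main diagonal,
# so the result is a 1-D tensor of length 3, not a batch of diagonal matrices.
single = torch.diag(v)

# torch.diag_embed() puts the last dimension on the diagonal of two new
# trailing dimensions, giving one diagonal matrix per batch entry.
batched = torch.diag_embed(v)  # shape (4, 3, 3)

# torch.diagonal() with dim1/dim2 extracts the diagonal of each matrix
# in the batch, recovering the original vectors.
recovered = torch.diagonal(batched, dim1=-2, dim2=-1)  # shape (4, 3)

print(single.shape, batched.shape, recovered.shape)
```

Whether torch.diag_embed() or torch.diagonal() is the right replacement depends on whether you are building diagonal matrices from vectors or reading diagonals out of matrices.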