Loss function
-
[pytorch] How to use nn.CrossEntropyLoss() 사용법AI 2020. 4. 21. 19:17
아래 코드는 pytorch에서 loss function으로 CrossEntropy를 사용하는 예이다. cls_loss = nn.CrossEntropyLoss() test_pred_y = torch.Tensor([[2,0.1],[0,1]]) # 실제 사용에선 softmax에 의해 각 행의 합이 1이 될 것이다. test_true_y1 = torch.Tensor([1,0]).long() # 1은 true값이 1번째(클래스)라는 것을 의미 test_true_y2 = torch.Tensor([0,1]).long() print(test_pred_y) print(test_true_y1) print(test_true_y2) print(cls_loss(test_pred_y, test_true_y1)) print(cls..