다음은 torch.cat을 실행한 결과입니다.
tensor([[1., 2., 3.], [2., 3., 4.]]) --------------- tensor([[1., 2., 3.], [2., 3., 4.], [4., 5., 6.], [5., 6., 7.]]) --------------- tensor([[1., 2., 3.], [2., 3., 4.], [4., 5., 6.], [5., 6., 7.]]) --------------- tensor([[1., 2., 3., 4., 5., 6.], [2., 3., 4., 5., 6., 7.]])
앞에서 정의한 get_predictions() 함수의 반환값을 각각 images, labels, probs에 저장하여 모델이 정확하게 예측한 이미지를 추출합니다.
코드 6-61 예측 중에서 정확하게 예측한 것을 추출
images, labels, probs = get_predictions(model, test_iterator)
pred_labels = torch.argmax(probs, 1) ------ ①
corrects = torch.eq(labels, pred_labels) ------ 예측과 정답이 같은지 비교
correct_examples = []
for image, label, prob, correct in zip(images, labels, probs, corrects): ------ ②
if correct:
correct_examples.append((image, label, prob))
correct_examples.sort(reverse=True, key=lambda x: torch.max(x[2], dim=0).values) ------ ③