더북(TheBook)

다음은 torch.cat을 실행한 결과입니다.

tensor([[1., 2., 3.],
        [2., 3., 4.]])
---------------
tensor([[1., 2., 3.],
        [2., 3., 4.],
        [4., 5., 6.],
        [5., 6., 7.]])
---------------
tensor([[1., 2., 3.],
        [2., 3., 4.],
        [4., 5., 6.],
        [5., 6., 7.]])
---------------
tensor([[1., 2., 3., 4., 5., 6.],
        [2., 3., 4., 5., 6., 7.]])

앞에서 정의한 get_predictions() 함수의 반환값을 각각 images, labels, probs에 저장하여 모델이 정확하게 예측한 이미지를 추출합니다.

코드 6-61 예측 중에서 정확하게 예측한 것을 추출

images, labels, probs = get_predictions(model, test_iterator)
pred_labels = torch.argmax(probs, 1) ------ ①
corrects = torch.eq(labels, pred_labels) ------ 예측과 정답이 같은지 비교
correct_examples = []

for image, label, prob, correct in zip(images, labels, probs, corrects): ------ ②
    if correct:
        correct_examples.append((image, label, prob))

correct_examples.sort(reverse=True, key=lambda x: torch.max(x[2], dim=0).values) ------ ③
신간 소식 구독하기
뉴스레터에 가입하시고 이메일로 신간 소식을 받아 보세요.