모델이 정확하게 예측한 이미지만 출력하기 위한 함수를 정의합니다. 이때 출력 방법에 대한 정의도 함께 진행해 보겠습니다.

    코드 6-63 모델이 정확하게 예측한 이미지 출력 함수

    def plot_most_correct(correct, classes, n_images, normalize=True):
        rows = int(np.sqrt(n_images)) ------ np.sqrt는 제곱근을 계산(0.5를 거듭제곱)
        cols = int(np.sqrt(n_images))
        fig = plt.figure(figsize=(25,20))
        for i in range(rows*cols):
            ax = fig.add_subplot(rows, cols, i+1) ------ 출력하려는 그래프 개수만큼 subplot을 만듭니다.
            image, true_label, probs = correct[i]
            image = image.permute(1, 2, 0) ------ ①
            true_prob = probs[true_label]
            correct_prob, correct_label = torch.max(probs, dim=0)
            true_class = classes[true_label]
            correct_class = classes[correct_label]
    
            if normalize: ------ 본래 이미지대로 출력하기 위해 normalize_image 함수 호출
                image = normalize_image(image)
    
            ax.imshow(image.cpu().numpy())
            ax.set_title(f'true label: {true_class} ({true_prob:.3f})\n' \
                         f'pred label: {correct_class} ({correct_prob:.3f})')
            ax.axis('off')
    
        fig.subplots_adjust(hspace=0.4)
    신간 소식 구독하기
    뉴스레터에 가입하시고 이메일로 신간 소식을 받아 보세요.