모델이 정확하게 예측한 이미지만 출력하기 위한 함수를 정의합니다. 이때 출력 방법에 대한 정의도 함께 진행해 보겠습니다.
코드 6-63 모델이 정확하게 예측한 이미지 출력 함수
def plot_most_correct(correct, classes, n_images, normalize=True):
rows = int(np.sqrt(n_images)) ------ np.sqrt는 제곱근을 계산(0.5를 거듭제곱)
cols = int(np.sqrt(n_images))
fig = plt.figure(figsize=(25,20))
for i in range(rows*cols):
ax = fig.add_subplot(rows, cols, i+1) ------ 출력하려는 그래프 개수만큼 subplot을 만듭니다.
image, true_label, probs = correct[i]
image = image.permute(1, 2, 0) ------ ①
true_prob = probs[true_label]
correct_prob, correct_label = torch.max(probs, dim=0)
true_class = classes[true_label]
correct_class = classes[correct_label]
if normalize: ------ 본래 이미지대로 출력하기 위해 normalize_image 함수 호출
image = normalize_image(image)
ax.imshow(image.cpu().numpy())
ax.set_title(f'true label: {true_class} ({true_prob:.3f})\n' \
f'pred label: {correct_class} ({correct_prob:.3f})')
ax.axis('off')
fig.subplots_adjust(hspace=0.4)