이제 predict_classes() 메서드를 사용하여 결과를 예측해 봅시다. 잘못된 예측은 빨간색으로 표시되고, 올바른 예측은 노란색으로 표시되도록 하겠습니다.

    코드 6-7 이미지 데이터셋 분류에 대한 예측

    class_names = ['cat', 'dog']
    validation, label_batch = next(iter(valid_generator))
    prediction_values = model.predict_classes(validation)
    
    fig = plt.figure(figsize=(12,8))
    fig.subplots_adjust(left=0, right=1, bottom=0, top=1, hspace=0.05, wspace=0.05)
    
    for i in range(8):
        ax = fig.add_subplot(2, 4, i+1, xticks=[], yticks=[])
        ax.imshow(validation[i,:], cmap=plt.cm.gray_r, interpolation='nearest')
    
        if prediction_values[i] == np.argmax(label_batch[i]):
            ax.text(3, 17, class_names[prediction_values[i]], color='yellow', fontsize=14)
        else:
            ax.text(3, 17, class_names[prediction_values[i]], color='red', fontsize=14)
    신간 소식 구독하기
    뉴스레터에 가입하시고 이메일로 신간 소식을 받아 보세요.