VGG19 모델은 파라미터가 1억 4000만 개로 네트워크를 훈련하는 데 시간이 오래 걸리므로 여기에서는 사전 훈련된 모델에서 가중치를 가져와서 사용하겠습니다. load_weights()를 사용하여 가중치를 가져옵니다. 가중치는 https://www.kaggle.com/keras/vgg19?select=vgg19_weights_tf_dim_ordering_tf_kernels.h5에서 내려받을 수 있습니다.5

    코드 6-17 사전 훈련된 VGG19 가중치 내려받기 및 클래스 정의

    model.load_weights("../chap6/data/vgg19_weights_tf_dim_ordering_tf_kernels.h5") ------ 사전 훈련된 VGG19 모델의 가중치 내려받기
    classes = {282: 'cat',
               681: 'notebook, notebook computer',
               970: 'alp'} ------ 검증용으로 사용될 클래스 세 개만 적용했으며, 전체 이미지에 대한 클래스는 “../chap6/data/”에 위치한 classes.txt 파일을 참고하세요.

    이제 이미지를 모델에 적용한 후 정확한 분류로 예측이 되었는지 검증하는 코드를 작성합니다.

    코드 6-18 이미지 호출 및 예측

    image1 = cv2.imread('../chap6/data/labtop.jpg')
    #image1 = cv2.imread('../chap6/data/starrynight.jpeg')
    #image1 = cv2.imread('../chap6/data/cat.jpg')
    image1 = cv2.resize(image1, (224,224))
    plt.figure()
    plt.imshow(image1)
    image1 = image1[np.newaxis, :] ------ 차원 확장(행을 추가)
    predicted_value = model.predict_classes(image1) ------ ①
    plt.title(classes[predicted_value[0]]) ------ 출력에 대한 title(제목) 지정

     

     


      5 이 파일은 https://www.kaggle.com/keras/vgg19?select=vgg19_weights_tf_dim_ordering_tf_kernels.h5에서 내려받습니다. 길벗출판사의 깃허브(https://github.com/gilbutITbook/080263) 첫 페이지의 URL을 클릭해도 내려받을 수 있습니다. 내려받은 파일은 data 폴더에 넣어 주세요.

    신간 소식 구독하기
    뉴스레터에 가입하시고 이메일로 신간 소식을 받아 보세요.