더북(TheBook)

Note ≡


역주 GridSearchCV 클래스와 cross_validate 함수에서 return_train_score 매개변수를 True로 지정하면 훈련 폴드에 대한 점수를 계산하여 반환합니다. 훈련 데이터셋에 대한 점수를 보고 과대적합과 과소적합에 대한 정보를 얻을 수 있지만 실행 시간이 오래 걸릴 수 있습니다. param_range에 여덟 개의 값이 지정되어 있기 때문에 SVC 모델은 'linear' 커널에 대해 여덟 번, 'rbf' 커널에 대해 64번의 교차 검증이 수행됩니다. 따라서 훈련 폴드마다 반환되는 점수는 총 72개입니다. 이 값은 GridSearchCV 클래스의 cv_results_ 딕셔너리 속성에 split{폴드번호}_train_score와 같은 키에 저장되어 있습니다. 예를 들어 첫 번째 폴드의 점수는 'split0_train_score' 키로 저장되어 있습니다.

>>> gs = GridSearchCV(estimator=pipe_svc,
...                   param_grid=param_grid,
...                   scoring='accuracy',
...                   cv=10,
...                   return_train_score=True,
...                   n_jobs=-1)
>>> gs = gs.fit(X_train, y_train)
>>> gs.cv_results_['split0_train_score']
array([0.6405868 ,  0.93643032,  0.97555012,  0.98777506,  0.98533007, 
       0.99266504,  0.99755501,  1.        ,  0.62591687,  0.62591687, 
       0.62591687,  0.62591687,  0.62591687,  0.62591687,  0.62591687, 
       0.62591687,  0.62591687,  0.62591687,  0.62591687,  0.62591687, 
       0.62591687,  0.62591687,  0.62591687,  0.62591687,  0.62591687, 
       0.62591687,  0.62591687,  0.62591687,  0.62591687,  0.62591687, 
       0.62591687,  0.62591687,  0.62591687,  0.7799511 ,  0.94621027, 
       0.96577017,  0.62591687,  0.62591687,  0.62591687,  0.62591687, 
       0.78484108,  0.94621027,  0.9804401 ,  0.99266504,  1.        , 
       1.        ,  1.        ,  1.        ,  0.94621027,  0.97799511, 
       0.99266504,  1.        ,  1.        ,  1.        ,  1.        , 
       1.        ,  0.97799511,  0.98777506,  0.99511002,  1.        , 
       1.        ,  1.        ,  1.        ,  1.        ,  0.98533007, 
       0.99266504,  1.        ,  1.        ,  1.        ,  1.        , 
       1.        ,  1.        ])

전체 훈련 점수의 평균값은 'mean_train_score' 키에 저장되어 있습니다.

>>> gs.cv_results_['mean_train_score']
array([0.6402928 ,  0.93724074,  0.97240801,  0.98510406,  0.98803447,
       0.99145447,  0.99707019,  0.9992677 ,  0.62637307,  0.62637307,
       0.62637307,  0.62637307,  0.62637307,  0.62637307,  0.62637307,
       0.62637307,  0.62637307,  0.62637307,  0.62637307,  0.62637307,
       0.62637307,  0.62637307,  0.62637307,  0.62637307,  0.62637307,
       0.62637307,  0.62637307,  0.62637307,  0.62637307,  0.62637307,
       0.62637307,  0.62637307,  0.62637307,  0.77070249,  0.94700817,
       0.97167094,  0.62637307,  0.62637307,  0.62637307,  0.62637307,
       0.77949371,  0.94725326,  0.97704753,  0.99291848,  1.        ,
       1.        ,  1.        ,  1.        ,  0.94652096,  0.97753354,
       0.99023257,  1.        ,  1.        ,  1.        ,  1.        ,
       1.        ,  0.97680064,  0.98852287,  0.99755799,  1.        ,
       1.        ,  1.        ,  1.        ,  1.        ,  0.98803387,
       0.99291848,  1.        ,  1.        ,  1.        ,  1.        ,
       1.        ,  1.        ])

비슷하게 첫 번째 폴드에 대한 테스트 점수는 'split0_test_score' 키에 저장되어 있습니다.

>>> gs.cv_results_['split0_test_score']
array([0.63043478,  0.89130435,  0.95652174,  0.97826087,  0.95652174,
       0.93478261,  0.95652174,  0.93478261,  0.63043478,  0.63043478,
       0.63043478,  0.63043478,  0.63043478,  0.63043478,  0.63043478,
       0.63043478,  0.63043478,  0.63043478,  0.63043478,  0.63043478,
       0.63043478,  0.63043478,  0.63043478,  0.63043478,  0.63043478,
       0.63043478,  0.63043478,  0.63043478,  0.63043478,  0.63043478,
       0.63043478,  0.63043478,  0.63043478,  0.69565217,  0.93478261,
       0.95652174,  0.63043478,  0.63043478,  0.63043478,  0.63043478,
       0.69565217,  0.93478261,  0.93478261,  1.        ,  0.63043478,
       0.63043478,  0.63043478,  0.63043478,  0.93478261,  0.97826087,
       1.        ,  1.        ,  0.63043478,  0.63043478,  0.63043478,
       0.63043478,  0.97826087,  0.97826087,  0.97826087,  1.        ,
       0.63043478,  0.63043478,  0.63043478,  0.63043478,  0.97826087,
       0.95652174,  0.95652174,  1.        ,  0.63043478,  0.63043478,
       0.63043478,  0.63043478])

GridSearchCV 클래스의 객체에서도 최종 모델의 score, predict, transform 메서드를 바로 호출할 수 있습니다.

>>> print('테스트 정확도: %.3f' % gs.score(X_test, y_test))
테스트 정확도: 0.974
신간 소식 구독하기
뉴스레터에 가입하시고 이메일로 신간 소식을 받아 보세요.