Note ≡
역주 cross_val_score 함수가 검증에 사용하는 기본 측정 지표는 회귀일 때는 R2, 분류일 때는 정확도입니다. scoring 매개변수를 사용하여 다른 지표로 바꿀 수 있습니다. 사이킷런 0.19 버전에서는 교차 검증에 여러 측정 지표를 사용할 수 있는 cross_validate 함수가 추가되었습니다. 이 함수는 각 폴드에서 훈련과 테스트에 걸린 시간을 반환하고 scoring 매개변수에 지정한 평가 지표마다 훈련 점수와 테스트 점수를 반환합니다. 반환된 딕셔너리에서 훈련 점수와 테스트 점수를 추출하려면 ‘train_XXXX’, ‘test_XXXX’ 형식의 키를 사용하면 됩니다. 앞 코드는 다음과 같이 바꾸어 쓸 수 있습니다.
>>> from sklearn.model_selection import cross_validate
>>> scores = cross_validate(estimator=pipe_lr,
... X=X_train,
... y=y_train,
... scoring=['accuracy'],
... cv=10,
... n_jobs=-1)
>>> print('CV 정확도 점수: %s' % scores['test_accuracy'])
CV 정확도 점수: [0.93478261 0.93478261 0.95652174
0.95652174 0.93478261 0.95555556
0.97777778 0.93333333 0.95555556
0.95555556]
>>> print('CV 정확도: %.3f +/- %.3f' % (np.mean(scores['test_accuracy']),
... np.std(scores['test_accuracy'])))
CV 정확도: 0.950 +/- 0.014
cross_val_predict 함수는 cross_val_score와 비슷한 인터페이스를 제공하지만 훈련 데이터셋의 각 샘플이 테스트 폴드가 되었을 때 만들어진 예측을 반환합니다. 따라서 cross_val_predict 함수의 결과를 사용하여 모델의 성능(예를 들어 정확도)을 계산하면 cross_val_score 함수의 결과와 다르며 바람직한 일반화 성능 추정이 아닙니다. cross_val_predict 함수는 훈련 데이터셋에 대한 예측 결과를 시각화하거나 7장에서 소개하는 스태킹(Stacking) 앙상블(Ensemble) 방법처럼 다른 모델에 주입할 훈련 데이터를 만들기 위해 사용할 수 있습니다.
>>> from sklearn.model_selection import cross_val_predict
>>> preds = cross_val_predict(estimator=pipe_lr,
... X=X_train,
... y=y_train,
... cv=10,v... n_jobs=-1)
>>> preds[:10]
array([0, 0, 0, 0, 0, 0, 0, 1, 1, 1])
method 매개변수에 반환될 값을 계산하기 위한 모델의 메서드를 지정할 수 있습니다. 예를 들어 method= 'predict_proba'로 지정하면 예측 확률을 반환합니다. 'predict', 'predict_proba', 'predict_log_proba', 'decision_function' 등이 가능하며 기본값은 'predict'입니다.
>>> from sklearn.model_selection import cross_val_predict
>>> preds = cross_val_predict(estimator=pipe_lr,
... X=X_train,
... y=y_train,
... cv=10,
... method='predict_proba',
... n_jobs=-1)
>>> preds[:10]
array([[9.93982352e-01, 6.01764759e-03],
[7.64328337e-01, 2.35671663e-01],
[9.72683946e-01, 2.73160539e-02],
[8.41658121e-01, 1.58341879e-01],
[9.97144940e-01, 2.85506043e-03],
[9.99803660e-01, 1.96339882e-04],
[9.99324159e-01, 6.75840609e-04],
[2.12145074e-06, 9.99997879e-01],
[1.28668437e-01, 8.71331563e-01],
[7.76260670e-04, 9.99223739e-01]])