rpart의 교차 검증
교차 테스트 데이터에 대해 예측값을 구해보자. create_ten_fold_cv( )는 train에 훈련 데이터를, validation에 검증 데이터를 담은 리스트의 리스트를 반환함을 기억하기 바란다. 다음은 데이터의 대략적인 모습을 보여준다.
List of 10 $ Fold01: List of 2 ..$ train ... ..$ validation ... $ Fold02: List of 2 ..$ train ... ..$ validation ... ... $ Fold10: List of 2 ..$ train ... ..$ validation ...
10개 폴드에 대한 예측값과 실제 값 데이터는 create_ten_fold_cv( )가 반환한 리스트를 순차적으로 방문하면서 validation에 대해 예측을 수행해 구할 수 있다.
> library(rpart) > library(foreach) > folds <- create_ten_fold_cv() > rpart_result <- foreach(f=folds) %do% { + model_rpart <- rpart( + survived ~ pclass + sex + age + sibsp + parch + fare + embarked, + data=f$train) + predicted <- predict(model_rpart, newdata=f$validation, + type="class") + return(list(actual=f$validation$survived, predicted=predicted)) > }
위의 코드에서 foreach는 리스트의 Fold01, Fold02 등을 f라는 변수로 받는다. 그리고 f$train과 f$validation을 사용해 rpart( ), predict( )를 수행한다. 결과는 actual에 생존 여부의 실제 값 그리고 predicted에 생존 여부의 예측값을 저장한 리스트로 반환되며, foreach는 folds 전체에 대한 결과를 또 다시 리스트로 묶는다.
이해를 돕기 위해 아래에 rpart_result의 일부를 보였다.
> head(rpart_result)
[[1]]
[[1]]$actual
[1] dead dead survived survived survived dead
[7] survived survived dead dead survived survived
...
Levels: dead survived
[[1]]$predicted
3 17 22 44 45 85
survived dead survived survived survived dead
86 104 107 111 117 128
...
Levels: dead survived
[[2]]
[[2]]$actual
...