더북(TheBook)

rpart의 교차 검증

교차 테스트 데이터에 대해 예측값을 구해보자. create_ten_fold_cv( )는 train에 훈련 데이터를, validation에 검증 데이터를 담은 리스트의 리스트를 반환함을 기억하기 바란다. 다음은 데이터의 대략적인 모습을 보여준다.

List of 10
  $ Fold01: List of 2
    ..$ train
      ...
    ..$ validation
      ...
  $ Fold02: List of 2
    ..$ train
      ...
    ..$ validation
      ...

  ...

  $ Fold10: List of 2
    ..$ train
      ...
    ..$ validation
      ...

10개 폴드에 대한 예측값과 실제 값 데이터는 create_ten_fold_cv( )가 반환한 리스트를 순차적으로 방문하면서 validation에 대해 예측을 수행해 구할 수 있다.

> library(rpart)
> library(foreach)
> folds <- create_ten_fold_cv()
> rpart_result <- foreach(f=folds) %do% {
+   model_rpart <- rpart(
+     survived ~ pclass + sex + age + sibsp + parch + fare + embarked,
+     data=f$train)
+   predicted <- predict(model_rpart, newdata=f$validation,
+                        type="class")
+   return(list(actual=f$validation$survived, predicted=predicted))
> }

위의 코드에서 foreach는 리스트의 Fold01, Fold02 등을 f라는 변수로 받는다. 그리고 f$train과 f$validation을 사용해 rpart( ), predict( )를 수행한다. 결과는 actual에 생존 여부의 실제 값 그리고 predicted에 생존 여부의 예측값을 저장한 리스트로 반환되며, foreach는 folds 전체에 대한 결과를 또 다시 리스트로 묶는다.

이해를 돕기 위해 아래에 rpart_result의 일부를 보였다.

> head(rpart_result)
[[1]]
[[1]]$actual
  [1] dead     dead     survived survived survived dead
  [7] survived survived dead     dead     survived survived
  ...
Levels: dead survived

[[1]]$predicted
       3       17       22       44       45       85
survived     dead survived survived survived     dead
      86     104       107      111      117      128
...
Levels: dead survived

[[2]]
[[2]]$actual
...
신간 소식 구독하기
뉴스레터에 가입하시고 이메일로 신간 소식을 받아 보세요.