rpart의 교차 검증

    교차 테스트 데이터에 대해 예측값을 구해보자. create_ten_fold_cv( )는 train에 훈련 데이터를, validation에 검증 데이터를 담은 리스트의 리스트를 반환함을 기억하기 바란다. 다음은 데이터의 대략적인 모습을 보여준다.

    List of 10
      $ Fold01: List of 2
        ..$ train
          ...
        ..$ validation
          ...
      $ Fold02: List of 2
        ..$ train
          ...
        ..$ validation
          ...
    
      ...
    
      $ Fold10: List of 2
        ..$ train
          ...
        ..$ validation
          ...
    

    10개 폴드에 대한 예측값과 실제 값 데이터는 create_ten_fold_cv( )가 반환한 리스트를 순차적으로 방문하면서 validation에 대해 예측을 수행해 구할 수 있다.

    > library(rpart)
    > library(foreach)
    > folds <- create_ten_fold_cv()
    > rpart_result <- foreach(f=folds) %do% {
    +   model_rpart <- rpart(
    +     survived ~ pclass + sex + age + sibsp + parch + fare + embarked,
    +     data=f$train)
    +   predicted <- predict(model_rpart, newdata=f$validation,
    +                        type="class")
    +   return(list(actual=f$validation$survived, predicted=predicted))
    > }
    

    위의 코드에서 foreach는 리스트의 Fold01, Fold02 등을 f라는 변수로 받는다. 그리고 f$train과 f$validation을 사용해 rpart( ), predict( )를 수행한다. 결과는 actual에 생존 여부의 실제 값 그리고 predicted에 생존 여부의 예측값을 저장한 리스트로 반환되며, foreach는 folds 전체에 대한 결과를 또 다시 리스트로 묶는다.

    이해를 돕기 위해 아래에 rpart_result의 일부를 보였다.

    > head(rpart_result)
    1
    1$actual
      [1] dead     dead     survived survived survived dead
      [7] survived survived dead     dead     survived survived
      ...
    Levels: dead survived
    
    1$predicted
           3       17       22       44       45       85
    survived     dead survived survived survived     dead
          86     104       107      111      117      128
    ...
    Levels: dead survived
    
    2
    2$actual
    ...
    
    신간 소식 구독하기
    뉴스레터에 가입하시고 이메일로 신간 소식을 받아 보세요.