더북(TheBook)

caret::createDataPartition(), createFolds(), createMultiFolds()

cvTools를 사용한 교차 검증은 데이터의 속성에 대한 고려 없이 무작위로 데이터를 나눈다. 그러나 좋은 모델 성능 평가가 되려면 예측하고자 하는 분류(Y)에 대한 고려가 필요하다. 예를 들어, 검증 데이터의 Species에 setosa는 너무 많고 versicolor, virginica는 너무 적다면 그 평가가 공정하지 않을 것이기 때문이다.

caret의 createDataPartition( ), createResample( ), createFolds( ), createMultiFolds( ), createTimeSlices( )는 Y 값을 고려한 훈련 데이터와 테스트 데이터의 분리를 지원하며, 이들 함수를 사용해 분리한 데이터는 Y 값의 비율이 원본 데이터와 같게 유지된다.

createDataPartition( )은 데이터를 훈련 데이터와 테스트 데이터로 분할한다. 부트스트래핑5 을 사용한 샘플링은 createResample( )에서 지원된다. 교차 검증을 원한다면 createFolds( ), creaetMultiFolds( )를 사용한다. 가장 기본이 되는 createDataParitition( )과 교차 검증에 대해 살펴보자.

표 9-18 caret을 사용한 교차 검증

caret::createDataPartition : 데이터를 훈련 데이터와 테스트 데이터로 분할한다.

caret::createDataPartition(
  y,          # 분류(또는 레이블)
  times=1,    # 생성할 분할의 수
  p=0.5,      # 훈련 데이터에서 사용할 데이터의 비율
  list=TRUE,  # 결과를 리스트로 반환할지 여부. FALSE면 행렬을 반환한다.
)

반환 값은 훈련 데이터로 사용할 데이터의 색인이다. 반환 값의 데이터 타입은 list 인자로 결정된다.

caret::createFolds : 데이터를 K겹 교차 검증으로 분할한다.

caret::createFolds(
  y,
  k=10,  # K 겹 교차 검증
  list=TRUE,
  # 훈련 데이터 색인을 반환할지 여부. FALSE면 검증 데이터 색인을 반환한다.
  returnTrain=FALSE
)

반환 값은 list, returnTrain에 의해 결정되는 데이터의 색인이다.

caret::createMultiFolds : 데이터를 K겹 교차 검증의 times 반복으로 분할한다.

caret::createMultiFolds(
  y,
  k=10,
  times=5  # 반복 횟수
)

반환 값은 훈련 데이터로 사용할 데이터의 색인이다

다음은 createDataPartition( )을 사용해 아이리스 데이터의 80%를 훈련 데이터, 나머지 20%를 검증 데이터로 분리한 예다.

> library(caret)
> (parts <- createDataPartition(iris$Species, p=0.8))
$Resample1
  [1]   1   2   4   6   7   8   9  11  12  14  15  16
 [13]  17  18  20  21  22  23  24  25  26  27  29  30
 [25]  31  32  34  35  37  38  39  40  41  42  43  46
 [37]  47  48  49  50  51  53  54  55  56  57  58  59
 [49]  61  62  63  64  65  66  67  68  69  71  72  73
 [61]  75  76  77  78  81  82  84  86  87  88  89  90
 [73]  91  92  93  95  97  98  99 100 101 102 104 105
 [85] 106 107 108 109 110 111 112 113 114 117 118 120
 [97] 121 122 123 124 126 128 130 131 132 135 136 137
[109] 138 139 140 141 142 143 144 145 146 147 149 150

> table(iris[parts$Resample1, "Species"])
  setosa versicolor virginica
      40         40        40

예에서 createDataPartition( )은 Species를 고려하여 데이터를 분리하고, 각 Species마다 40개씩을 훈련 데이터로 추출했다. parts$Resample1에 포함되지 않은 행들은 검증 데이터로 사용하면 되며, 다음에서 볼 수 있듯이 검증 데이터에서는 Species마다 10개씩 데이터가 할당된다.

> table(iris[-parts$Resample1, "Species"])
  setosa versicolor virginica
      10         10        10

createFolds( )는 K겹 교차 검증을 지원한다. 다음은 아이리스 데이터를 10겹 교차 검증한 예다. 리스트의 각 요소 [[1]], [[2]], … 등에는 검증 데이터로 사용할 데이터의 색인이 저장되어 있다.

> createFolds(iris$Species, k=10)
$Fold01
 [1]  10  17  32  39  43  54  60  80  88  90 101 114 120 135 139

$Fold02
 [1]   2   5  16  44  47  63  66  77  92  99 103 113 123 140 150

$Fold03
 [1]   8  24  34  40  50  59  73  79  82  86 118 122 141 143 149

$Fold04
 [1]   4  11  14  30  37  51  53  65  74  91 106 107 127 137 148

$Fold05
 [1]   9  13  38  46  48  56  78  81  93 100 108 112 119 126 144

$Fold06
 [1]  12  21  31  33  35  64  72  76  87  95 102 105 115 124 138

$Fold07
 [1]  15  23  26  28  49  52  57  58  67  75 121 128 130 133 142

$Fold08
 [1]   3   7  29  42  45  55  68  84  85  97 109 117 125 146 147

$Fold09
 [1]   1  18  20  22  41  69  71  89  94  98 129 132 134 136 145

$Fold10
 [1]   6  19  25  27  36  61  62  70  83  96 104 110 111 116 131

createMultiFolds( )는 K겹 교차 검증의 times회 반복을 지원한다. 예를 들어, 다음은 아이리스에 대한 10겹 교차 검증의 3회 반복이다.

> createMultiFolds(iris$Species, k=10, times=3)
$Fold01.Rep1
  [1]   1   2   3   4   5   6   7   8   9  10  11  12  13  14  17  18
 [17]  19  21  22  23  24  25  26  27  28  29  31  32  33  34  35  36
 [33]  37  38  39  41  42  43  44  45  46  47  48  49  50  51  52  54
 [49]  56  57  58  59  60  61  62  64  65  66  67  69  70  71  72  73
 [65]  74  75  76  77  78  79  80  81  82  83  84  85  86  87  88  89
 [81]  90  91  92  93  95  96  97  98  99 100 101 102 103 104 105 106
 [97] 107 108 109 110 111 113 114 115 116 117 118 119 120 121 124 125
[113] 126 127 128 129 130 131 132 133 134 135 136 137 139 140 141 142
[129] 143 144 145 146 147 148 149

$Fold02.Rep1
  [1]   2   3   4   5  6   7    8   9  10  11  12  13  15  16  17  18
 [17]  19  20  21  22  23  25  26  27  28  29  30  31  32  33  35  36
 [33]  37  38  39  40  41  42  43  44  45  46  47  48  50  51  52  53
 [49]  54  55  56  57  58  59  60  61  63  64  65  66  68  69  71  72
 [65]  73  74  75  76  77  78  79  80  82  83  84  85  86  87  88  90
 [81]  91  92  93  94  95  96  97  98  99 100 101 102 103 104 105 107
 [97] 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123
[113] 125 127 128 129 130 131 132 134 135 136 137 138 139 141 142 143
[129] 144 145 146 147 148 149 150

...

$Fold10.Rep3
  [1]   1   2   3   4   5   6   9  10  11  12  13  14  15  16  17  18
 [17]  19  20  21  23  24  25  26  27  28  29  30  31  32  33  34  36
 [33]  37  38  39  40  41  42  44  45  46  47  48  49  50  51  52  53
 [49]  54  55  57  58  60  61  62  63  64  65  66  67  68  69  71  72
 [65]  73  74  75  76  77  78  79  80  81  82  83  84  85  86  87  88
 [81]  89  91  92  93  94  95  96  98  99 100 102 103 104 105 106 107
 [97] 108 109 110 111 112 113 114 115 116 117 118 120 121 122 123 124
[113] 125 126 127 128 130 132 134 135 136 137 138 139 140 141 142 143
[129] 144 145 146 147 148 149 150

리스트의 각 요소는 ‘Fold 번호.Rep 번호’ 형태로 이름이 붙어 있으며 훈련 데이터로 사용할 데이터의 색인이 저장되어 있다. 실제 사용할 때는 리스트의 각 셀에 부여된 이름과 무관하게 [[i]] 형태로 색인을 가져와서 훈련 데이터와 검증 데이터를 분리하면 된다.

k <- 10
times <- 3
set.seed(137)
cv <- createMultiFolds(iris$Species, k, times)

for (i in 1:times) {
  for (j in 1:k) {
    train_idx <- cv[[i*times + k]]
    iris.train <- iris[train_idx, ]
    iris.validation <- iris[-train_idx, ]
    # 모델링 수행
    ...
    # 평가
    ...
  }
}

5 부트스트래핑(Bootstrapping)은 복원 추출을 사용해 표본의 분포를 추정하는 기법이다. ‘복원’이란 여러 가지 색의 공이 들어 있는 항아리에서 공을 꺼내 그 색을 확인한 뒤 다음 공을 꺼내기 전에 지금 공을 ‘다시 항아리에 넣는’ 표본 추출 방법을 뜻한다. 반면 비복원 추출은 공의 색을 확인하고 공을 항아리에 다시 넣지 않은 채 다음 공을 꺼낸다. 따라서 부트스트래핑에서는 같은 데이터가 여러 번 훈련 데이터로 선택될 수 있다.[18]

신간 소식 구독하기
뉴스레터에 가입하시고 이메일로 신간 소식을 받아 보세요.