2차원 평면에서 두 개의 클래스로 구성된 점들을 SVM 알고리즘으로 분류하고, 그 경계면을 화면에 표시하는 예제 프로그램 소스 코드를 코드 15-5에 나타냈습니다. 코드 15-5는 (150, 200), (200, 250), (100, 250), (150, 300) 점들을 0번 클래스로 정의하고, (350, 100), (400, 200), (400, 300), (350, 400) 점들을 1번 클래스로 정의한 후, SVM 알고리즘을 이용하여 이 두 점들을 효과적으로 분리하는 초평면을 구하여 화면에 나타냅니다. 코드 15-5에 나타난 소스 코드 파일은 내려받은 예제 파일 중 ch15/svmplane 프로젝트에서 확인할 수 있습니다.
코드 15-5 SVM 알고리즘을 이용한 2차원 점 분류 [ch15/svmplane]
01 #include "opencv2/opencv.hpp" 02 #include <iostream> 03 04 using namespace cv; 05 using namespace cv::ml; 06 using namespace std; 07 08 int main(void) 09 { 10 Mat train = Mat_<float>({ 8, 2 }, { 11 150, 200, 200, 250, 100, 250, 150, 300, 12 350, 100, 400, 200, 400, 300, 350, 400 }); 13 Mat label = Mat_<int>({ 8, 1 }, { 0, 0, 0, 0, 1, 1, 1, 1 }); 14 15 Ptr<SVM> svm = SVM::create(); 16 svm->setType(SVM::Types::C_SVC); 17 svm->setKernel(SVM::KernelTypes::RBF); 18 svm->trainAuto(train, ROW_SAMPLE, label); 19 20 Mat img = Mat::zeros(Size(500, 500), CV_8UC3); 21 22 for (int j = 0; j < img.rows; j++) { 23 for (int i = 0; i < img.cols; i++) { 24 Mat test = Mat_<float>({ 1, 2 }, { (float)i, (float)j }); 25 int res = cvRound(svm->predict(test)); 26 27 if (res = = 0) 28 img.at<Vec3b>(j, i) = Vec3b(128, 128, 255); // R 29 else 30 img.at<Vec3b>(j, i) = Vec3b(128, 255, 128); // G 31 } 32 } 33 34 for (int i = 0; i < train.rows; i++) { 35 int x = cvRound(train.at<float>(i, 0)); 36 int y = cvRound(train.at<float>(i, 1)); 37 int l = label.at<int>(i, 0); 38 39 if (l = = 0) 40 circle(img, Point(x, y), 5, Scalar(0, 0, 128), -1, LINE_AA); // R 41 else 42 circle(img, Point(x, y), 5, Scalar(0, 128, 0), -1, LINE_AA); // G 43 } 44 45 imshow("svm", img); 46 47 waitKey(); 48 return 0; 49 }