코드 15-1 kNN 알고리즘을 이용한 2차원 점 분류 [ch15/knnplane]
01 #include "opencv2/opencv.hpp" 02 #include <iostream> 03 04 using namespace cv; 05 using namespace cv::ml; 06 using namespace std; 07 08 Mat img; 09 Mat train, label; 10 Ptr<KNearest> knn; 11 int k_value = 1; 12 13 void on_k_changed(int, void*); 14 void addPoint(const Point& pt, int cls); 15 void trainAndDisplay(); 16 17 int main(void) 18 { 19 img = Mat::zeros(Size(500, 500), CV_8UC3); 20 knn = KNearest::create(); 21 22 namedWindow("knn"); 23 createTrackbar("k", "knn", &k_value, 5, on_k_changed); 24 25 const int NUM = 30; 26 Mat rn(NUM, 2, CV_32SC1); 27 28 randn(rn, 0, 50); 29 for (int i = 0; i < NUM; i++) 30 addPoint(Point(rn.at<int>(i, 0) + 150, rn.at<int>(i, 1) + 150), 0); 31 32 randn(rn, 0, 50); 33 for (int i = 0; i < NUM; i++) 34 addPoint(Point(rn.at<int>(i, 0) + 350, rn.at<int>(i, 1) + 150), 1); 35 36 randn(rn, 0, 70); 37 for (int i = 0; i < NUM; i++) 38 addPoint(Point(rn.at<int>(i, 0) + 250, rn.at<int>(i, 1) + 400), 2); 39 40 trainAndDisplay(); 41 42 waitKey(); 43 return 0; 44 } 45 46 void on_k_changed(int, void*) 47 { 48 if (k_value < 1) k_value = 1; 49 trainAndDisplay(); 50 } 51 52 void addPoint(const Point& pt, int cls) 53 { 54 Mat new_sample = (Mat_<float>(1, 2) << pt.x, pt.y); 55 train.push_back(new_sample); 56 57 Mat new_label = (Mat_<int>(1, 1) << cls); 58 label.push_back(new_label); 59 } 60 61 void trainAndDisplay() 62 { 63 knn->train(train, ROW_SAMPLE, label); 64 65 for (int i = 0; i < img.rows; ++i) { 66 for (int j = 0; j < img.cols; ++j) { 67 Mat sample = (Mat_<float>(1, 2) << j, i); 68 69 Mat res; 70 knn->findNearest(sample, k_value, res); 71 72 int response = cvRound(res.at<float>(0, 0)); 73 if (response = = 0) 74 img.at<Vec3b>(i, j) = Vec3b(128, 128, 255); // R 75 else if (response = = 1) 76 img.at<Vec3b>(i, j) = Vec3b(128, 255, 128); // G 77 else if (response = = 2) 78 img.at<Vec3b>(i, j) = Vec3b(255, 128, 128); // B 79 } 80 } 81 82 for (int i = 0; i < train.rows; i++) 83 { 84 int x = cvRound(train.at<float>(i, 0)); 85 int y = cvRound(train.at<float>(i, 1)); 86 int l = label.at<int>(i, 0); 87 88 if (l = = 0) 89 circle(img, Point(x, y), 5, Scalar(0, 0, 128), -1, LINE_AA); 90 else if (l = = 1) 91 circle(img, Point(x, y), 5, Scalar(0, 128, 0), -1, LINE_AA); 92 else if (l = = 2) 93 circle(img, Point(x, y), 5, Scalar(128, 0, 0), -1, LINE_AA); 94 } 95 96 imshow("knn", img); 97 }