더북(TheBook)

코드 15-1 kNN 알고리즘을 이용한 2차원 점 분류 [ch15/knnplane]

01    #include "opencv2/opencv.hpp"
02    #include <iostream>
03     
04    using namespace cv;
05    using namespace cv::ml;
06    using namespace std;
07     
08    Mat img;
09    Mat train, label;
10    Ptr<KNearest> knn;
11    int k_value = 1;
12     
13    void on_k_changed(int, void*);
14    void addPoint(const Point& pt, int cls);
15    void trainAndDisplay();
16     
17    int main(void)
18    {
19        img = Mat::zeros(Size(500, 500), CV_8UC3);
20        knn = KNearest::create();
21     
22        namedWindow("knn");
23        createTrackbar("k", "knn", &k_value, 5, on_k_changed);
24     
25        const int NUM = 30;
26        Mat rn(NUM, 2, CV_32SC1);
27     
28        randn(rn, 0, 50);
29        for (int i = 0; i < NUM; i++)
30            addPoint(Point(rn.at<int>(i, 0) + 150, rn.at<int>(i, 1) + 150), 0);
31     
32        randn(rn, 0, 50);
33        for (int i = 0; i < NUM; i++)
34            addPoint(Point(rn.at<int>(i, 0) + 350, rn.at<int>(i, 1) + 150), 1);
35     
36        randn(rn, 0, 70);
37        for (int i = 0; i < NUM; i++)
38            addPoint(Point(rn.at<int>(i, 0) + 250, rn.at<int>(i, 1) + 400), 2);
39     
40        trainAndDisplay();
41     
42        waitKey();
43        return 0;
44    }
45     
46    void on_k_changed(int, void*)
47    {
48        if (k_value < 1) k_value = 1;
49        trainAndDisplay();
50    }
51     
52    void addPoint(const Point& pt, int cls)
53    {
54        Mat new_sample = (Mat_<float>(1, 2) << pt.x, pt.y);
55        train.push_back(new_sample);
56     
57        Mat new_label = (Mat_<int>(1, 1) << cls);
58        label.push_back(new_label);
59    }
60     
61    void trainAndDisplay()
62    {
63        knn->train(train, ROW_SAMPLE, label);
64     
65        for (int i = 0; i < img.rows; ++i) {
66            for (int j = 0; j < img.cols; ++j) {
67                Mat sample = (Mat_<float>(1, 2) << j, i);
68     
69                Mat res;
70                knn->findNearest(sample, k_value, res);
71     
72                int response = cvRound(res.at<float>(0, 0));
73                if (response = = 0)
74                    img.at<Vec3b>(i, j) = Vec3b(128, 128, 255); // R
75                else if (response = = 1)
76                    img.at<Vec3b>(i, j) = Vec3b(128, 255, 128); // G
77                else if (response = = 2)
78                    img.at<Vec3b>(i, j) = Vec3b(255, 128, 128); // B
79            }
80        }
81     
82        for (int i = 0; i < train.rows; i++)
83        {
84            int x = cvRound(train.at<float>(i, 0));
85            int y = cvRound(train.at<float>(i, 1));
86            int l = label.at<int>(i, 0);
87     
88            if (l = = 0)
89                circle(img, Point(x, y), 5, Scalar(0, 0, 128), -1, LINE_AA);
90            else if (l = = 1)
91                circle(img, Point(x, y), 5, Scalar(0, 128, 0), -1, LINE_AA);
92            else if (l = = 2)
93                circle(img, Point(x, y), 5, Scalar(128, 0, 0), -1, LINE_AA);
94        }
95     
96        imshow("knn", img);
97    }

신간 소식 구독하기
뉴스레터에 가입하시고 이메일로 신간 소식을 받아 보세요.