參考了 Adrien Le Batteux 和 OpenCV 的 kmeans2 範例。
不過寫得落落長，實在不怎麼 simple @@
點集分群結果:
程式碼:
struct FontObject { char str[256]; CvFont font; //字型物件 CvPoint pos; CvScalar color; FontObject (double sw=0.25, double sh=0.25, int thickness = 1) { color = cvScalar (255,255,255); cvInitFont (&font, CV_FONT_HERSHEY_SIMPLEX, sw,sh, 0, thickness, CV_AA); } void show (IplImage* out, int x, int y) { pos.x = x; pos.y = y; sprintf (str, "(%d,%d)", x, y); cvPutText (out, str, pos, &font, color); } void show (IplImage* out, int x, int y, char* txt) { pos.x = x; pos.y = y; cvPutText (out, txt, pos, &font, color); } }; FontObject fontObj (0.75, 0.75, 2); template <typename T> void shuffle (T* a, int n) //簡易洗牌 { while (n > 1) { int j = rand() % n; n--; T t = a[j]; a[j] = a[n]; a[n] = t; } } template <typename T> inline void set_vector (T* out, T* in, int n) { while (n-- > 0) out[n] = in[n]; } template <typename T> inline void add_vector (T* out, T* in, int n) { while (n-- > 0) out[n] += in[n]; } template <typename T> inline void div_vector (T* v, int divisor, int n) { while (n-- > 0) v[n] /= divisor; } template <typename T> inline double cal_distance (T* a, T* b, int n) { double sum = 0, diff; while (n-- > 0) { diff = a[n] - b[n]; sum += diff * diff; } return sqrt (sum); } template <typename T = float> struct Kmeans { int nIteration; //最大迭代次數 int nData; //資料項數 int* label; //資料項的分類標籤 int* ccount; //群心涵蓋的資料個數 T** data; //指向資料集合 T** center; //群心集合 T* new_center; //計算暫用的群心 T* mean; //資料項的幾何中心 int nCenter; //群心個數 int dim; //資料維度 double error; //群心更新時的最大移動量 Kmeans() { srand (time(0)); label = ccount = 0; center = data = 0; new_center = mean = 0; } ~Kmeans() { release(); } private: void set (T* out, T* in) {set_vector (out, in, dim);} void add (T* out, T* in) {add_vector (out, in, dim);} void div (T* v, int divisor) {div_vector (v, divisor, dim);} double distance (T* a, T* b) {return cal_distance (a, b, dim);} void release() //釋放資源 { if (label) del_arr (label); label = 0; if (ccount) del_arr (ccount); ccount = 0; if (center) del_arr (center); center = 0; if (mean) del_arr (mean); 
mean = 0; if (new_center) del_arr (new_center); new_center = 0; } void init_centers () { if (nCenter < 1) ERROR_MSG_("群心數目需 > 1"); if (nCenter > nData) nCenter = nData; release(); //釋放資源 new_center = (T*) new_arr (sizeof(T), 1, dim); mean = (T*) new_arr (sizeof(T), 1, dim); center = (T**) new_arr (sizeof(T), 2, nCenter, dim); ccount = (int*) new_arr (sizeof(int), 1, nData); label = (int*) new_arr (sizeof(int), 1, nData); int i; for (i=0; i<nCenter; ++i) //隨機指定樣本點為群心 set (center[i], data [rand() % nData]); for (i=0; i<nData; ++i) //計算資料項的幾何中心 add (mean, data[i]); div (mean, nData); } void cluster_assignment () //將樣本分群 { int label_id, j, k; double min_dist, dist; memset (ccount, 0, nData* sizeof(int)); for (j=0; j<nData; ++j) { label_id = 0; min_dist = distance (center[0], data[j]); for (k=1; k<nCenter; ++k) { dist = distance (center[k], data[j]); if (dist < min_dist) { label_id = k; min_dist = dist; } } label[j] = label_id; //歸類 ccount[label_id]++; //統計群心涵蓋的資料個數 } } void update_cluster () //重新估算 centroids 位置 { int j, k, total; double dist = 0; error = 0; for (k=0; k<nCenter; ++k) { memset (new_center, 0, dim* sizeof(T)); total = 0; for (j=0; j<nData; ++j) if (label[j] == k) { add (new_center, data[j]); total++; } if (total > 0) { div (new_center, total); dist = distance (center[k], new_center); if (error < dist) error = dist; set (center[k], new_center); } else set (center[k], mean); } } public: //算法主流程: nCenter = 群心數目, // data = 資料集合 = data [nData_] [dim_] void do_cluster (int nCenter_, double error_threshold, T** data_, int nData_, int dim_) { nCenter = nCenter_; //設定群心數目 data = data_; //指向資料集合 nData = nData_; //設定資料個數 dim = dim_; //設定資料維度 init_centers (); //初始化群心位置 for (int i=0; i<nIteration && error > error_threshold; ++i) { cluster_assignment (); //將樣本分群 update_cluster (); //重新估算 centroids 位置 } } //單步執行 clustering, 回傳群心最大移動量 void setup (int nCenter_, T** data_, int nData_, int dim_) { nCenter = nCenter_; //設定群心數目 data = data_; //指向資料集合 nData = nData_; //設定資料個數 
dim = dim_; //設定資料維度 init_centers (); //初始化群心位置 } double iterater_cluster () { cluster_assignment (); //將樣本分群 update_cluster (); //重新估算 centroids 位置 return error; } }; //----------------------------------------------------------------- template <class T> struct KMeanTest { enum {CD = 3, MAX_CLUSTERS = (CD+1)*(CD+1)*(CD+1), //最大群心數 W=700, H=700}; //展示影像的寬高 bool bEnd; CvScalar color_tbl[MAX_CLUSTERS]; IplImage* img; //2D展示影像 Kmeans<T> kmean; T** data; //資料集合 int nCenter; //群心個數 int nData; //資料項數 int dim; //資料維度 private: void release() //釋放資源 { if (data) del_arr (data); data = 0; } void generate_color_table() { int r, g, b, i=0, d = 255/CD; for (r=0; r<256; r+=d) for (g=0; g<256; g+=d) for (b=0; b<256; b+=d) color_tbl[i++] = CV_RGB (r,g,b); shuffle (color_tbl+1, MAX_CLUSTERS-2); } void generate_sample () { release(); data = (T**) new_arr (sizeof(T), 2, nData, dim); for (int i=0; i<nData; ++i) { data[i][0] = rand() % W; data[i][1] = rand() % H; } } public: KMeanTest() { generate_color_table(); cvNamedWindow ("clusters", 1); data = 0; img = cvCreateImage (cvSize (W, H), 8, 3); } ~KMeanTest() { cvDestroyWindow ("clusters"); cvReleaseImage (&img); release(); } void show() { static CvScalar white = CV_RGB (255,255,255); static char txt[9]; int i, key; cvZero (img); for (i=0; i<nData; ++i) //畫出資料項 { CvPoint pt = cvPoint (data[i][0], data[i][1]); CvScalar color = color_tbl [kmean.label[i] + 1]; cvCircle (img, pt, 2, color, CV_FILLED, CV_AA, 0); } for (i=0; i<nCenter; ++i) //畫出群心位置 { T* c = kmean.center[i]; cvCircle (img, cvPoint (c[0],c[1]), 3, white, CV_FILLED, CV_AA, 0); sprintf (txt, "%d", i); //顯示群心標號 fontObj.show (img, c[0]+3, c[1], txt); } cvShowImage ("clusters", img ); key = cvWaitKey(0); if (key == 27 || key == 'q' || key == 'Q') //ESC bEnd = true; } void start (int nCenter_, int nData_, int dim_, double threshold) { bEnd = false; nCenter = nCenter_; nData = nData_; dim = dim_; if (nCenter > MAX_CLUSTERS-2) { cout << "群心數目不得大於 " << MAX_CLUSTERS-2; return; } 
generate_sample (); kmean.setup (nCenter, data, nData, dim); while (!bEnd && kmean.iterater_cluster() > threshold) show(); } }; int main () { KMeanTest<double> test; test.start(24, 10000, 2, 5); }
沒有留言:
張貼留言