2009年12月9日 星期三

Simple K-Means

 
參考了 Adrien Le Batteux 和 OpenCV 的 kmeans2 範例。
不過寫得落落長,實在不怎麼 simple @@

點集分群結果:


程式碼:
// Bundles an OpenCV font, a text color and a scratch buffer, and draws
// text onto an image through cvPutText.
struct FontObject 
{
    char     str[256];      // scratch buffer for the formatted "(x,y)" label
    CvFont   font;          // font descriptor filled in by cvInitFont
    CvPoint  pos;           // position handed to the most recent cvPutText call
    CvScalar color;         // text color; the constructor sets (255,255,255)

    // sw and sh are forwarded to cvInitFont as its horizontal / vertical
    // scale arguments; thickness is forwarded as the stroke thickness.
    FontObject (double sw=0.25, double sh=0.25, int thickness = 1)
    {
        // Give every member a defined value so the object can be inspected
        // or copied before the first show() call.
        str[0] = '\0';
        pos.x  = 0;
        pos.y  = 0;
        color  = cvScalar (255,255,255);
        cvInitFont (&font, CV_FONT_HERSHEY_SIMPLEX,
            sw,sh, 0, thickness, CV_AA);
    }

    // Draws the text "(x,y)" at position (x, y) of `out`.
    void show (IplImage* out, int x, int y)
    {
        // Two ints plus punctuation stay far below the 256-byte buffer.
        sprintf (str, "(%d,%d)", x, y);
        show (out, x, y, str);
    }

    // Draws the caller-supplied, NUL-terminated `txt` at position (x, y)
    // of `out`. The text is only read, so it is taken as const: string
    // literals and const buffers are accepted, and existing callers that
    // pass a char* keep compiling.
    void show (IplImage* out, int x, int y, const char* txt)
    {
        pos.x = x;
        pos.y = y;
        cvPutText (out, txt, pos, &font, color);
    }
};

// File-scope text renderer (scale 0.75 x 0.75, thickness 2); KMeanTest::show
// uses it to draw the index next to each cluster center.
FontObject fontObj (0.75, 0.75, 2);


// Simple in-place shuffle of the first n elements of a, driven by rand().
// Each pass picks a random element of the not-yet-fixed prefix and swaps it
// into the last slot of that prefix. Does nothing when n <= 1.
template <typename T>
void shuffle (T* a, int n)
{
    for (int remaining = n; remaining > 1; --remaining) {
        const int pick = rand() % remaining;
        const int last = remaining - 1;
        T held  = a[pick];
        a[pick] = a[last];
        a[last] = held;
    }
}

// Copies the first n components of `in` into `out`, walking from the last
// index down to the first. Does nothing when n <= 0.
template <typename T>
inline void set_vector (T* out, T* in, int n)
{
    for (int i = n - 1; i >= 0; --i) {
        out[i] = in[i];
    }
}

// Adds the first n components of `in` onto `out` component by component,
// walking from the last index down to the first. Does nothing when n <= 0.
template <typename T>
inline void add_vector (T* out, T* in, int n)
{
    for (int i = n - 1; i >= 0; --i) {
        out[i] += in[i];
    }
}

// Divides each of the first n components of v by `divisor` in place,
// walking from the last index down to the first. Does nothing when n <= 0.
template <typename T>
inline void div_vector (T* v, int divisor, int n)
{
    for (int i = n - 1; i >= 0; --i) {
        v[i] /= divisor;
    }
}

// Returns the Euclidean distance between the first n components of a and b.
// Squared differences are accumulated in a double, from the last index down
// to the first; the result is 0 when n <= 0.
template <typename T>
inline double cal_distance (T* a, T* b, int n)
{
    double squared = 0;
    for (int i = n - 1; i >= 0; --i) {
        const double delta = a[i] - b[i];
        squared += delta * delta;
    }
    return sqrt (squared);
}

// K-means clustering over a caller-owned table of nData rows with dim
// components each.
//
// Usage: call do_cluster() to iterate until the centers stop moving, or call
// setup() once and then iterater_cluster() repeatedly to step by hand.
// Afterwards label[j] is the index of the center assigned to data[j] and
// center[k] holds the k-th center.
//
// Notes:
//  - The constructor reseeds the C library generator with srand(time(0)).
//  - label / ccount / center / mean / new_center are raw buffers obtained
//    from new_arr and freed with del_arr (both defined outside this block).
//    No copy operations are defined, so copying an instance after setup
//    would leave two objects releasing the same buffers: do not copy it.
//  - data is only borrowed; it is never released here.
template <typename T = float>
struct Kmeans
{
    int    nIteration;              // iteration cap used by do_cluster()
    int    nData;                   // number of data rows
    int*   label;                   // label[j] = center index assigned to data[j]
    int*   ccount;                  // ccount[k] = number of rows assigned to center k
    T**    data;                    // borrowed pointer to the data rows
    T**    center;                  // the cluster centers
    T*     new_center;              // scratch center used while updating
    T*     mean;                    // component-wise mean of all data rows
    int    nCenter;                 // number of centers (clamped to nData)
    int    dim;                     // components per row
    double error;                   // largest center movement in the last update
    
    Kmeans() {
        srand (time(0)); 
        // These used to be left indeterminate. 100 is a default cap chosen
        // here; callers may assign nIteration before do_cluster().
        nIteration = 100;
        error      = 0;
        nData = nCenter = dim = 0;
        label  = ccount = 0;
        center = data   = 0;
        new_center = mean = 0; 
    }
   ~Kmeans() {
        release();
    }

private:

    // Thin wrappers that apply the free vector helpers with this->dim.
    void set (T* out, T* in)      {set_vector (out, in, dim);}
    void add (T* out, T* in)      {add_vector (out, in, dim);}
    void div (T* v, int divisor)  {div_vector (v, divisor, dim);}
    double distance (T* a, T* b)  {return cal_distance (a, b, dim);}

    // Frees every buffer this object allocated and nulls the pointers.
    void release()
    {
        if (label)      del_arr (label);      label  = 0;
        if (ccount)     del_arr (ccount);     ccount = 0;
        if (center)     del_arr (center);     center = 0;
        if (mean)       del_arr (mean);       mean   = 0;
        if (new_center) del_arr (new_center); new_center = 0;
    }

    // (Re)allocates the working buffers, seeds every center with a randomly
    // chosen data row and computes the mean of all rows.
    void init_centers ()
    {
        // The check rejects nCenter < 1, so the requirement is ">= 1";
        // the message previously claimed "> 1".
        if (nCenter < 1)     ERROR_MSG_("群心數目需 >= 1");
        if (nCenter > nData) nCenter = nData; 

        release();
        
        new_center = (T*) new_arr (sizeof(T), 1, dim);
        mean   = (T*)   new_arr (sizeof(T),   1, dim);
        center = (T**)  new_arr (sizeof(T),   2, nCenter, dim);
        ccount = (int*) new_arr (sizeof(int), 1, nData);
        label  = (int*) new_arr (sizeof(int), 1, nData);

        int i;

        for (i=0; i<nCenter; ++i)           // seed each center from a random row
            set (center[i], data [rand() % nData]);

        // mean is accumulated with +=, so it must start at zero. new_arr is
        // not visible from here; clear explicitly (as update_cluster does
        // for new_center) instead of relying on it.
        memset (mean, 0, dim* sizeof(T));
        for (i=0; i<nData; ++i)
            add (mean, data[i]);
        if (nData > 0)                      // avoid dividing by zero on empty input
            div (mean, nData);
    }

    // Assigns every row to its nearest center and counts rows per center.
    void cluster_assignment ()
    {
        int     label_id, j, k;
        double  min_dist, dist;
        
        memset (ccount, 0, nData* sizeof(int));

        for (j=0; j<nData; ++j) {
            label_id = 0;
            min_dist = distance (center[0], data[j]);
            for (k=1; k<nCenter; ++k) {
                dist = distance (center[k], data[j]);
                if (dist < min_dist) {      // strict '<': ties keep the lower index
                    label_id = k;
                    min_dist = dist;
                }
            }
            label[j] = label_id;
            ccount[label_id]++;
        }
    }

    // Moves each center to the mean of the rows assigned to it. A center
    // with no rows is moved to the overall data mean. `error` ends up as the
    // largest distance any center moved, including such re-seeding moves
    // (they were previously not counted).
    void update_cluster ()
    {
        int j, k, total;
        double dist = 0;
        error = 0;

        for (k=0; k<nCenter; ++k)           
        {
            memset (new_center, 0, dim* sizeof(T));
            total = 0;
            for (j=0; j<nData; ++j) 
                if (label[j] == k) {
                    add (new_center, data[j]);
                    total++;
                }

            T* target = mean;               // fallback for an empty cluster
            if (total > 0) {
                div (new_center, total);
                target = new_center;
            }

            dist = distance (center[k], target);
            if (error < dist) 
                error = dist;
            set (center[k], target);
        }        
    }

public:

    // Runs the whole algorithm.
    //   nCenter_        number of centers
    //   error_threshold stop once the largest center movement is <= this
    //   data_           rows, indexed data_[0..nData_-1][0..dim_-1]
    //   nData_, dim_    table dimensions
    // At most nIteration assignment/update rounds are performed.
    void do_cluster 
        (int nCenter_, double error_threshold, 
         T** data_, int nData_, int dim_)
    {
        nCenter = nCenter_;
        data    = data_;
        nData   = nData_;
        dim     = dim_;
        
        init_centers ();

        // `error` only has a meaning after an update, so test it at the end
        // of each round. The previous loop header read it before it had
        // ever been assigned.
        for (int i=0; i<nIteration; ++i) {
            cluster_assignment ();
            update_cluster ();
            if (error <= error_threshold)
                break;
        }
    }


    // Prepares for stepwise clustering: stores the inputs and initialises
    // the centers. Follow with iterater_cluster().
    void setup (int nCenter_, T** data_, int nData_, int dim_)
    {
        nCenter = nCenter_;
        data    = data_;
        nData   = nData_;
        dim     = dim_;
        init_centers ();
    }

    // Performs one assignment + update round and returns the largest center
    // movement of that round.
    double iterater_cluster ()
    {
        cluster_assignment ();
        update_cluster ();
        return error;
    }
};



//-----------------------------------------------------------------

// Interactive demo: generates random 2-D samples, clusters them with
// Kmeans<T> one step at a time and draws every step in an OpenCV window
// named "clusters". Each step waits for a key; ESC, 'q' or 'Q' stops.
template <class T>
struct KMeanTest
{
    enum {CD = 3, 
          MAX_CLUSTERS = (CD+1)*(CD+1)*(CD+1),  // size of the color table
          W=700, H=700};                        // width / height of the image
    bool bEnd;                              // set once the user asks to stop

    CvScalar color_tbl[MAX_CLUSTERS];
    IplImage* img;                          // image the clusters are drawn on
    Kmeans<T> kmean;
    T** data;                               // sample rows, owned by this object
    int nCenter;                            // number of centers actually in use
    int nData;                              // number of sample rows
    int dim;                                // components per row

private:
    // Frees the sample table.
    void release()
    {
        if (data) del_arr (data); data = 0;
    }

    // Fills color_tbl with every combination of the channel levels
    // 0, d, 2d, ... below 256 (d = 255/CD), i.e. (CD+1)^3 entries, then
    // shuffles all entries except the first and the last.
    void generate_color_table()
    {
        int r, g, b, i=0, d = 255/CD;
        for (r=0; r<256; r+=d)
            for (g=0; g<256; g+=d)
                for (b=0; b<256; b+=d)
                    color_tbl[i++] = CV_RGB (r,g,b);
        shuffle (color_tbl+1, MAX_CLUSTERS-2);
    }

    // Allocates nData rows and fills the first two components with random
    // positions inside the W x H image. start() guarantees dim >= 2.
    void generate_sample ()
    {      
        release();
        data = (T**) new_arr (sizeof(T), 2, nData, dim);
        for (int i=0; i<nData; ++i) {
            data[i][0] = rand() % W;
            data[i][1] = rand() % H;
            // Components beyond the two that are drawn still take part in
            // the distance computation, so give them a defined value.
            for (int k=2; k<dim; ++k)
                data[i][k] = 0;
        }
    }

public:
    KMeanTest()
    {
        generate_color_table();
        cvNamedWindow ("clusters", 1);
        data = 0;
        // Defined values so show() is harmless before start() has run.
        bEnd    = false;
        nCenter = 0;
        nData   = 0;
        dim     = 0;
        img = cvCreateImage (cvSize (W, H), 8, 3);
    }
   ~KMeanTest()
    {
        cvDestroyWindow ("clusters");
        cvReleaseImage (&img);
        release();
    }

    // Draws the samples (colored by label) and the centers (white, with
    // their index), shows the image and blocks until a key is pressed.
    void show()
    {
        static CvScalar white = CV_RGB (255,255,255);
        static char txt[9];
        int i, key;
        
        cvZero (img);
        for (i=0; i<nData; ++i)                 // draw the samples
        {
            CvPoint  pt    = cvPoint (data[i][0], data[i][1]);            
            // +1 skips table entry 0, which generate_color_table leaves at
            // CV_RGB(0,0,0).
            CvScalar color = color_tbl [kmean.label[i] + 1];
            cvCircle (img, pt, 2, color, CV_FILLED, CV_AA, 0);
        }
        for (i=0; i<nCenter; ++i)              // draw the centers
        {            
            T* c = kmean.center[i];
            cvCircle (img, cvPoint (c[0],c[1]), 3, 
                white, CV_FILLED, CV_AA, 0);
            sprintf (txt, "%d", i);            // center index as a label
            fontObj.show (img, c[0]+3, c[1], txt); 
        }
        cvShowImage ("clusters", img );
        key = cvWaitKey(0);
        if (key == 27 || key == 'q' || key == 'Q')  // 27 = ESC
            bEnd = true;
    }

    // Generates nData_ random samples of dimension dim_ and clusters them
    // into nCenter_ groups, showing every step until the largest center
    // movement is <= threshold or the user quits.
    void start (int nCenter_, int nData_, int dim_, double threshold)
    {
        // Validate before touching any member so a rejected call leaves the
        // object in its previous, consistent state.
        if (nCenter_ > MAX_CLUSTERS-2) {
            cout << "群心數目不得大於 " << MAX_CLUSTERS-2;
            return;
        }
        // generate_sample() and show() both access components 0 and 1.
        if (dim_ < 2) {
            cout << "資料維度不得小於 2";
            return;
        }

        bEnd    = false;
        nCenter = nCenter_;
        nData   = nData_;
        dim     = dim_;

        generate_sample ();
        kmean.setup (nCenter, data, nData, dim);
        // Kmeans clamps its center count to nData; adopt the clamped value
        // so show() never indexes past kmean.center.
        nCenter = kmean.nCenter;

        while (!bEnd && 
               kmean.iterater_cluster() > threshold)
            show();
    }
};


int main ()
{
    KMeanTest<double> test;
    test.start(24, 10000, 2, 5);
}

沒有留言:

張貼留言