数据结构应用实例(二)——K均值聚类
一、问题描述
对Iris数据进行分类,数据从文件读入。Iris包含150个四维的数据,这些数据可以看做是四维空间中的点。根据这些点在空间中的位置分布,将这150个特征点分成三类,分类的依据是欧氏距离,同类点之间的距离比较小;反之,不同类别的点之间的距离会比较大。
二、代码实现
#include<stdio.h>
#include<stdlib.h>
#include<math.h>
#define N 150 //数据点个数
#pragma warning(disable:4996)//设置全局变量
float point[N+1][5];//存放数据点,0号位置不存放数据
int nums[3];//存放同类节点个数,nums[i-1]表示第i类节点个数
int group[3][N];//存放同类节点编号,group[i-1][0,,,num[i]-1]存放第i类节点编号float aver[3][5];//同类点的算术平均中心,aver[i-1][0]表示i类点的个数,aver[i-1][1,,,4]表示i类中心点的坐标
float centers[3][5];//存储聚类中心,centers[i-1][1,,,4]表示i类中心点的坐标void getdata();//从文件中读取数据,存放在数组point中,0号位置不放数据
float dis(float *a,float *b);//计算a,b两点间的距离
void classify();//对数据点进行分类
void calAver();//计算同类的算术平均中心main()
{int i,j;int count=0;//迭代次数//1.初始化getdata();//读取数据//利用前三个点初始化聚类中心for (i = 1; i <= 3; i++){for (j = 1; j <= 4; j++)centers[i-1][j] = point[i][j];}//2.对数据点进行分类classify();//3.计算算术平均中心calAver();//4.如果算术平均中心和聚类中心不重合且迭代次数没有超过给定值,继续迭代while (dis(aver[0],centers[0])+dis(aver[1],centers[1])+dis(aver[2],centers[2])!=0.0 && count<=50){ count++;//迭代次数增加//更新聚类中心for (i = 0; i < 3; i++){for (j = 1; j <= 4; j++)centers[i][j] = aver[i][j];}//重新分类classify();//再次计算算术平均中心calAver();}//5.显示结果if (dis(aver[0],centers[0])+dis(aver[1],centers[1])+dis(aver[2],centers[2]) != 0.0)printf("迭代次数超过最大值,该组数据不能分成三类,程序结束.\n");else{printf("迭代次数为:%d.\n\n", count);for (i = 1; i <= 3; i++)//分类显示,对于第i类{printf("%d类中心为:", i);for (j = 1; j <= 4; j++)printf("%.2f ", centers[i-1][j]);printf("\n共%d个数据点,编号为:\n", nums[i-1]);for (j = 0; j < nums[i - 1]; j++){printf("%4d",group[i-1][j]);if((j+1)%10==0)printf("\n");}printf("\n");if(j%10!=0)printf("\n");}}
}void getdata()//从文件中读取数据,存放在数组point中,0号位置不放数据
{int i;FILE *fp=fopen("Iris.txt","r");if (!fp){printf("Iris.txt 文件读取失败.\n");exit(-1);}//从文件中读取数据for(i=1;(!feof(fp))&&(i<=N);i++)fscanf(fp, "%f%f%f%f",&(point[i][1]),&(point[i][2]),&(point[i][3]),&(point[i][4]));fclose(fp);
}float dis(float *a, float *b)//计算a,b两点间的距离
{return sqrt((a[1]-b[1])*(a[1]-b[1])+(a[2]-b[2])*(a[2]-b[2])+(a[3]-b[3])*(a[3]-b[3])+(a[4]-b[4])*(a[4]-b[4]));
}void classify()//对数据点进行分类
{int i,temp;float dis1,dis2,dis3;//1.先清空原有分类for(i=0;i < 3;i++)nums[i]=0;for (i = 1; i <= N; i++){//2.计算距三个中心点的距离dis1=dis(centers[0],point[i]);dis2=dis(centers[1],point[i]);dis3=dis(centers[2],point[i]);//3.分类if (dis1 <= dis2 && dis1 <= dis3)//属于1类temp=1;else if(dis2 <= dis3)//属于2类temp=2;else//属于3类temp=3;group[temp-1][nums[temp-1]]=i;nums[temp-1]++;}
}void calAver()//计算同类的算术平均中心
{int i,j,h;int index;//1.置零for (i = 1; i <= 3; i++){for (j = 1; j <= 4; j++)aver[i-1][j] = 0;}//2.对三类节点各个分量求平均for (i = 1; i <= 3; i++){if (nums[i-1] == 0){printf("第%d类不含数据点.\n",i);exit(-1);}for (j = 0; j < nums[i-1]; j++)//nums[i-1]表示第i类节点数目{index=group[i-1][j];//节点编号for (h = 1; h <= 4; h++)//四个分量累加aver[i-1][h] += point[index][h];} for (h = 1; h <= 4; h++)//四个分量和除以个数得到平均值aver[i-1][h] /= nums[i-1];}
}
三、小结
1、 用二维数组 point 存储数据点,点在数组中的位置就是点的编号;
2、 设置二维数组 group 存放各类节点编号,nums 存放各类节点个数,在计算平均值的时候可以直接实现数据点的随机读取,大大提升算法的效率;
3、 几乎将全部的变量设置成全局变量,避免参数传递;
4、 如果在某次分类之后,某一类不含节点,当即进行提示,并结束程序,防止出现除0操作;
5、 为防止结果不收敛,设定最大迭代次数进行控制;
6、 如果结果收敛,最后得到的算术平均中心和聚类中心完全重合,没有偏差,因为最后一次迭代前后,所有数据点的类别没有发生变化;