Skip to content

Commit 90460bd

Browse files
committed
K-Means算法分析与python实现
K-Means算法分析与python实现,CSDN博客地址:http://blog.csdn.net/gamer_gyt/article/details/48949227 scikit-learn下使用K-means:http://blog.csdn.net/gamer_gyt/article/details/51244850
1 parent 06e5134 commit 90460bd

File tree

1 file changed

+83
-0
lines changed

1 file changed

+83
-0
lines changed

K-means/kMeans.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
#coding:utf-8
2+
'''
3+
Created on 2015年11月9日
4+
5+
@author: Administrator
6+
'''
7+
import numpy as np
8+
9+
def kMeans(X, k, maxIt):
10+
#横纵坐标
11+
numPoints, numDim = X.shape
12+
dataSet = np.zeros((numPoints, numDim + 1))
13+
dataSet[: ,: -1 ] = X
14+
#随机产生新的中心点
15+
centroids = dataSet[np.random.randint(numPoints, size = k), :]
16+
#centroids = dataSet[0:2,:]
17+
#给k个中心点标签赋值为[1,k+1]
18+
centroids[:, -1] = range(1, k+1)
19+
20+
iterations = 0 #循环次数
21+
oldCentroids = None #用来储存旧的中心点
22+
23+
while not shouldStop(oldCentroids, centroids, iterations, maxIt):
24+
print "iterations:\n ",iterations
25+
print "dataSet: \n",dataSet
26+
print "centroids:\n ",centroids
27+
28+
#用copy的原因是进行复制,不用=是因为=相当于同时指向一个地址,一个改变另外一个也会改变
29+
oldCentroids = np.copy(centroids)
30+
iterations += 1
31+
32+
#更新中心点
33+
updataLabels(dataSet, centroids)
34+
35+
#得到新的中心点
36+
centroids = getCentroids(dataSet, k)
37+
38+
return dataSet
39+
40+
def shouldStop(oldCentroids, centroids, iterations, maxIt):
41+
if iterations > maxIt:
42+
return True
43+
return np.array_equal(oldCentroids, centroids)
44+
45+
46+
def updataLabels(dataSet, centroids):
47+
numPoints, numDim = dataSet.shape
48+
for i in range(0,numPoints):
49+
dataSet[i,-1] = getLabelFromCloseestCentroid(dataSet[i, :-1],centroids)
50+
51+
52+
def getLabelFromCloseestCentroid(dataSetRow, centroids):
53+
label = centroids[0, -1]
54+
#np.linalg.norm() 计算两个向量之间的距离
55+
minDist = np.linalg.norm(dataSetRow - centroids[0, :-1])
56+
for i in range(1,centroids.shape[0]):
57+
dist = np.linalg.norm(dataSetRow - centroids[i,:-1])
58+
if dist < minDist:
59+
minDist = dist
60+
label = centroids[i,-1]
61+
62+
print "minDist :\n" ,minDist
63+
return label
64+
65+
66+
def getCentroids(dataSet, k):
67+
result = np.zeros((k, dataSet.shape[1]))
68+
for i in range(1, k+1):
69+
oneCluster = dataSet[dataSet[:,-1] == i,:-1]
70+
result[i - 1, :-1] = np.mean(oneCluster, axis=0)
71+
result[i - 1, -1] = i
72+
73+
return result
74+
75+
76+
x1 = np.array([1, 1])
77+
x2 = np.array([2, 1])
78+
x3 = np.array([4, 3])
79+
x4 = np.array([5, 4])
80+
testX = np.vstack((x1, x2, x3, x4))
81+
82+
result = kMeans(testX, 2, 10)
83+
print "final result: \n",result

0 commit comments

Comments
 (0)