@@ -32,40 +32,55 @@ def get_hog_features(trainset):
3232def Predict (testset ,trainset ,train_labels ):
3333 predict = []
3434 count = 0
35+
3536 for test_vec in testset :
37+ # 输出当前运行的测试用例坐标,用于测试
3638 print count
3739 count += 1
3840
39- knn_list = []
41+ knn_list = [] # 当前k个最近邻居
42+ max_index = - 1 # 当前k个最近邻居中距离最远点的坐标
43+ max_dist = 0 # 当前k个最近邻居中距离最远点的距离
4044
41- for i in range (len (train_labels )):
45+ # 先将前k个点放入k个最近邻居中,填充满knn_list
46+ for i in range (k ):
4247 label = train_labels [i ]
4348 train_vec = trainset [i ]
4449
45- dist = np .linalg .norm (train_vec - test_vec )
50+ dist = np .linalg .norm (train_vec - test_vec ) # 计算两个点的欧氏距离
4651
47- if len (knn_list ) < k : # 如果还不够10个邻近点则直接添加即可
48- knn_list .append ((dist ,label ))
49- else :
50- max_index = - 1
51- max_dist = dist
52+ knn_list .append ((dist ,label ))
53+
54+ # 剩下的点
55+ for i in range (k ,len (train_labels )):
56+ label = train_labels [i ]
57+ train_vec = trainset [i ]
58+
59+ dist = np .linalg .norm (train_vec - test_vec ) # 计算两个点的欧氏距离
5260
53- # 寻找10个邻近点钟距离最远的点
61+ # 寻找10个邻近点钟距离最远的点
62+ if max_index < 0 :
5463 for j in range (k ):
5564 if max_dist < knn_list [j ][0 ]:
5665 max_index = j
5766 max_dist = knn_list [max_index ][0 ]
5867
59- if max_index >= 0 :
60- knn_list [max_index ] = (dist ,label )
68+ # 如果当前k个最近邻居中存在点距离比当前点距离远,则替换
69+ if dist < max_dist :
70+ knn_list [max_index ] = (dist ,label )
71+ max_index = - 1
72+
6173
74+ # 统计选票
6275 class_total = 10
6376 class_count = [0 for i in range (class_total )]
6477 for dist ,label in knn_list :
6578 class_count [label ] += 1
6679
80+ # 找出最大选票
6781 mmax = max (class_count )
6882
83+ # 找出最大选票标签
6984 for i in range (class_total ):
7085 if mmax == class_count [i ]:
7186 predict .append (i )
0 commit comments