@@ -32,40 +32,55 @@ def get_hog_features(trainset):
3232def  Predict (testset ,trainset ,train_labels ):
3333    predict  =  []
3434    count  =  0 
35+ 
3536    for  test_vec  in  testset :
37+         # 输出当前运行的测试用例坐标,用于测试 
3638        print  count 
3739        count  +=  1 
3840
39-         knn_list  =  []
41+         knn_list  =  []       # 当前k个最近邻居 
42+         max_index  =  - 1       # 当前k个最近邻居中距离最远点的坐标 
43+         max_dist  =  0         # 当前k个最近邻居中距离最远点的距离 
4044
41-         for  i  in  range (len (train_labels )):
45+         # 先将前k个点放入k个最近邻居中,填充满knn_list 
46+         for  i  in  range (k ):
4247            label  =  train_labels [i ]
4348            train_vec  =  trainset [i ]
4449
45-             dist  =  np .linalg .norm (train_vec  -  test_vec )
50+             dist  =  np .linalg .norm (train_vec  -  test_vec )          # 计算两个点的欧氏距离 
4651
47-             if  len (knn_list ) <  k :                               # 如果还不够10个邻近点则直接添加即可 
48-                 knn_list .append ((dist ,label ))
49-             else :
50-                 max_index  =  - 1 
51-                 max_dist  =  dist 
52+             knn_list .append ((dist ,label ))
53+ 
54+         # 剩下的点 
55+         for  i  in  range (k ,len (train_labels )):
56+             label  =  train_labels [i ]
57+             train_vec  =  trainset [i ]
58+ 
59+             dist  =  np .linalg .norm (train_vec  -  test_vec )         # 计算两个点的欧氏距离 
5260
53-                 # 寻找10个邻近点钟距离最远的点 
61+             # 寻找10个邻近点钟距离最远的点 
62+             if  max_index  <  0 :
5463                for  j  in  range (k ):
5564                    if  max_dist  <  knn_list [j ][0 ]:
5665                        max_index  =  j 
5766                        max_dist  =  knn_list [max_index ][0 ]
5867
59-                 if  max_index  >=  0 :
60-                     knn_list [max_index ] =  (dist ,label )
68+             # 如果当前k个最近邻居中存在点距离比当前点距离远,则替换 
69+             if  dist  <  max_dist :
70+                 knn_list [max_index ] =  (dist ,label )
71+                 max_index  =  - 1 
72+ 
6173
74+         # 统计选票 
6275        class_total  =  10 
6376        class_count  =  [0  for  i  in  range (class_total )]
6477        for  dist ,label  in  knn_list :
6578            class_count [label ] +=  1 
6679
80+         # 找出最大选票 
6781        mmax  =  max (class_count )
6882
83+         # 找出最大选票标签 
6984        for  i  in  range (class_total ):
7085            if  mmax  ==  class_count [i ]:
7186                predict .append (i )
0 commit comments