22
33Object clustering with k-means algorithm 
44
5- 
65author: Atsushi Sakai (@Atsushi_twi) 
76
87""" 
98
9+ import  numpy  as  np 
10+ import  math 
1011import  matplotlib .pyplot  as  plt 
1112import  random 
1213
1314
14- class  Cluster :
class Clusters:
    """Hold 2D sample points together with their clustering state.

    Attributes:
        x, y: point coordinates, stored as given (no copy is made).
        ndata: number of points, taken from ``len(x)``.
        nlabel: number of cluster labels.
        labels: one label per point, drawn at random from
            ``0 .. nlabel - 1`` at construction time.
        cx, cy: centroid coordinates, one entry per label, all ``0.0``
            until they are recomputed.
    """

    def __init__(self, x, y, nlabel):
        self.x = x
        self.y = y
        self.ndata = len(self.x)
        self.nlabel = nlabel

        # Draw one random label per point (inclusive upper bound).
        initial_labels = []
        for _ in range(self.ndata):
            initial_labels.append(random.randint(0, nlabel - 1))
        self.labels = initial_labels

        # Two independent lists so cx and cy can be updated separately.
        self.cx = [0.0] * nlabel
        self.cy = [0.0] * nlabel
26+ 
27+ 
def init_clusters(rx, ry, nc):
    """Create the initial clustering state.

    Args:
        rx: x coordinates of the points.
        ry: y coordinates of the points.
        nc: number of cluster labels.

    Returns:
        A new ``Clusters`` instance built from the arguments.
    """
    return Clusters(rx, ry, nc)
33+ 
34+ 
def calc_centroid(clusters):
    """Recompute every centroid as the mean of the points with its label.

    A label that currently owns no points keeps its previous centroid
    instead of triggering a division by zero.

    Args:
        clusters: object exposing ``x``, ``y``, ``labels``, ``nlabel``,
            ``cx`` and ``cy``; ``cx``/``cy`` are updated in place.

    Returns:
        The same ``clusters`` object, for call chaining.
    """
    nlabel = clusters.nlabel
    sum_x = [0.0] * nlabel
    sum_y = [0.0] * nlabel
    counts = [0] * nlabel

    # Single pass over the points: accumulate per-label sums and counts.
    for px, py, label in zip(clusters.x, clusters.y, clusters.labels):
        if 0 <= label < nlabel:  # labels outside the range match no centroid
            sum_x[label] += px
            sum_y[label] += py
            counts[label] += 1

    for ic in range(nlabel):
        if counts[ic] == 0:
            # Empty cluster: no mean is defined, so leave the centroid as is.
            continue
        clusters.cx[ic] = sum_x[ic] / counts[ic]
        clusters.cy[ic] = sum_y[ic] / counts[ic]

    return clusters
1544
16-     def  __init__ (self ):
17-         self .x  =  []
18-         self .y  =  []
19-         self .cx  =  None 
20-         self .cy  =  None 
45+ 
def update_clusters(clusters):
    """Assign every point to its nearest centroid and compute the total cost.

    The cost is the sum, over all points, of the Euclidean distance to the
    centroid the point was assigned to. Ties go to the lowest label index.

    Args:
        clusters: object exposing ``x``, ``y``, ``cx``, ``cy`` and a
            mutable ``labels`` sequence; ``labels`` is updated in place.

    Returns:
        Tuple ``(clusters, cost)`` with the same ``clusters`` object and
        the accumulated distance cost as a float.
    """
    cost = 0.0

    for ip, (px, py) in enumerate(zip(clusters.x, clusters.y)):
        # Distance from this point to each centroid, in label order.
        dlist = [math.hypot(icx - px, icy - py)
                 for icx, icy in zip(clusters.cx, clusters.cy)]
        mind = min(dlist)
        clusters.labels[ip] = dlist.index(mind)
        # Accumulate the distance itself, not the label index.
        cost += mind

    return clusters, cost
2163
2264
def kmean_clustering(rx, ry, nc, max_loop=10, dcost_th=1.0):
    """Cluster 2D points into ``nc`` groups with the k-means iteration.

    Starts from a random labelling, then alternates between reassigning
    points to their nearest centroid and recomputing the centroids. Stops
    when the cost changes by less than ``dcost_th`` between two consecutive
    iterations, or after ``max_loop`` iterations.

    Args:
        rx: x coordinates of the points.
        ry: y coordinates of the points.
        nc: number of clusters.
        max_loop: maximum number of update iterations.
        dcost_th: convergence threshold on the absolute cost change.

    Returns:
        The final clustering state object.
    """
    clusters = init_clusters(rx, ry, nc)
    clusters = calc_centroid(clusters)

    # No previous cost exists before the first iteration; infinity ensures
    # the convergence test cannot fire by coincidence on the first pass.
    pcost = float("inf")
    for loop in range(max_loop):
        print("Loop:", loop)
        clusters, cost = update_clusters(clusters)
        clusters = calc_centroid(clusters)

        if abs(cost - pcost) < dcost_th:
            break
        pcost = cost

    return clusters
3584
@@ -40,17 +89,30 @@ def calc_raw_data():
4089
4190    cx  =  [0.0 , 5.0 ]
4291    cy  =  [0.0 , 5.0 ]
43-     np  =  30 
92+     npoints  =  30 
4493    rand_d  =  3.0 
4594
4695    for  (icx , icy ) in  zip (cx , cy ):
47-         for  _  in  range (np ):
96+         for  _  in  range (npoints ):
4897            rx .append (icx  +  rand_d  *  (random .random () -  0.5 ))
4998            ry .append (icy  +  rand_d  *  (random .random () -  0.5 ))
5099
51100    return  rx , ry 
52101
53102
def calc_labeled_points(ic, clusters):
    """Return the coordinates of all points currently labelled ``ic``.

    Args:
        ic: cluster label to select.
        clusters: object exposing ``x``, ``y`` and ``labels`` sequences of
            equal length.

    Returns:
        Tuple ``(x, y)`` of NumPy arrays holding the selected coordinates.
        Both arrays are empty when no point carries the label.
    """
    # A boolean mask stays a valid index even when nothing matches, unlike
    # an empty integer-index array built from a list (which gets a float
    # dtype and is rejected as an index).
    mask = np.asarray(clusters.labels) == ic

    x = np.asarray(clusters.x)[mask]
    y = np.asarray(clusters.y)[mask]

    return x, y
114+ 
115+ 
54116def  main ():
55117    print (__file__  +  " start!!" )
56118
@@ -59,11 +121,10 @@ def main():
59121    ncluster  =  2 
60122    clusters  =  kmean_clustering (rx , ry , ncluster )
61123
62-     for  c  in  clusters :
63-         print (c .cx , c .cy )
64-         plt .plot (c .cx , c .cy , "x" )
65- 
66-     plt .plot (rx , ry , "." )
124+     for  ic  in  range (clusters .nlabel ):
125+         x , y  =  calc_labeled_points (ic , clusters )
126+         plt .plot (x , y , "x" )
127+     plt .plot (clusters .cx , clusters .cy , "o" )
67128    plt .show ()
68129
69130
0 commit comments