1+ from  sklearn  import  datasets 
2+ import  numpy  as  np 
3+ from  sklearn .model_selection  import  train_test_split 
4+ ## Example 1: iris for classification( 3 classes) 
5+ # X, y = datasets.load_iris(return_X_y=True) 
6+ # Example 2 
7+ X , y  =  datasets .load_breast_cancer (return_X_y = True )
8+ 
9+ X_train , X_test , y_train , y_test  =  train_test_split (X , y , test_size = 0.3 , random_state = 42 )
10+ 
11+ # my k-NN 
12+ # kd-tree 
13+ class  KDNode :
14+     ''' 
15+     vaule: [X,y] 
16+     ''' 
17+     def  __init__ (self , value = None , parent = None , left = None , right = None , index = None ):
18+         self .value  =  value 
19+         self .parent  =  parent 
20+         self .left  =  left 
21+         self .right  =  right  
22+     @property  
23+     def  brother (self ):
24+         if  not  self .parent :
25+             bro  =  None 
26+         else :
27+             if  self .parent .left  is  self :
28+                 bro  =  self .parent .right 
29+             else :
30+                 bro  =  self .parent .left 
31+         return  bro 
32+ 
33+ class  KDTree ():
34+     def  __init__ (self ,K = 3 ):
35+         self .root  =  KDNode ()
36+         self .K  =  K 
37+         
38+     def  _build (self , data , axis = 0 ,parent = None ):
39+         ''' 
40+         data:[X,y] 
41+         ''' 
42+         # choose median point  
43+         if  len (data ) ==  0 :
44+             root  =  KDNode ()
45+             return  root 
46+         data  =  np .array (sorted (data , key = lambda  x :x [axis ]))
47+         median  =  int (len (data )/ 2 )
48+         loc  =  data [median ]
49+         root  =  KDNode (loc ,parent = parent )
50+         new_axis  =  (axis + 1 )% (len (data [0 ])- 1 )
51+         if  len (data [:median ,:]) ==  0 :
52+             root .left  =  None 
53+         else :
54+             root .left  =  self ._build (data [:median ,:],axis = new_axis ,parent = root )
55+         if  len (data [median + 1 :,:]) ==  0 :
56+             root .right  =  None 
57+         else :
58+             root .right  =  self ._build (data [median + 1 :,:],axis = new_axis ,parent = root )
59+         self .root  =  root 
60+         return  root 
61+     
62+     def  fit (self , X , y ):
63+         # concat X,y 
64+         data  =  np .concatenate ([X , y .reshape (- 1 ,1 )],axis = 1 )
65+         root  =  self ._build (data )
66+         
67+     def  _get_eu_distance (self ,arr1 :np .ndarray , arr2 :np .ndarray ) ->  float :
68+         return  ((arr1  -  arr2 ) **  2 ).sum () **  0.5 
69+         
70+     def  _search_node (self ,current ,point ,result = {},class_count = {}):
71+         # Get max_node, max_distance. 
72+         if  not  result :
73+             max_node  =  None 
74+             max_distance  =  float ('inf' )
75+         else :
76+             # find the nearest (node, distance) tuple 
77+             max_node , max_distance  =  sorted (result .items (), key = lambda  n :n [1 ],reverse = True )[0 ]
78+         node_dist  =  self ._get_eu_distance (current .value [:- 1 ],point )
79+         if  len (result ) ==  self .K :
80+             if  node_dist  <  max_distance :
81+                 result .pop (max_node )
82+                 result [current ] =  node_dist 
83+                 class_count [current .value [- 1 ]] =  class_count .get (current .value [- 1 ],0 ) +  1 
84+                 class_count [max_node .value [- 1 ]] =  class_count .get (max_node .value [- 1 ],1 ) -  1 
85+         elif  len (result ) <  self .K :
86+             result [current ] =  node_dist 
87+             class_count [current .value [- 1 ]] =  class_count .get (current .value [- 1 ],0 ) +  1 
88+         return  result ,class_count 
89+         
90+     def  search (self ,point ):
91+         # find the point belongs to which leaf node(current). 
92+         current  =  self .root 
93+         axis  =  0 
94+         while  current :
95+             if  point [axis ] <  current .value [axis ]:
96+                 prev  =  current 
97+                 current  =  current .left 
98+             else :
99+                 prev  =  current 
100+                 current  =  current .right 
101+             axis  =  (axis + 1 )% len (point )
102+         current  =  prev 
103+         # search k nearest points 
104+         result  =  {}
105+         class_count = {}
106+         while  current :
107+             result ,class_count  =  self ._search_node (current ,point ,result ,class_count )
108+             if  current .brother :
109+                 result ,class_count  =  self ._search_node (current .brother ,point ,result ,class_count )
110+             current  =  current .parent 
111+         return  result ,class_count 
112+         
113+     def  predict (self ,X_test ):
114+         predict  =  [0  for  _  in  range (len (X_test ))]
115+         for  i  in  range (len (X_test )):
116+             _ ,class_count  =  self .search (X_test [i ])
117+             sorted_count  =  sorted (class_count .items (), key = lambda  x : x [1 ],reverse = True )
118+             predict [i ] =  sorted_count [0 ][0 ]
119+         return  predict 
120+         
121+     def  score (self ,X ,y ):
122+         y_pred  =  self .predict (X )
123+         return  1  -  np .count_nonzero (y - y_pred )/ len (y )
124+         
125+     def  print_tree (self ,X_train ,y_train ):  
126+         height  =  int (math .log (len (X_train ))/ math .log (2 ))
127+         max_width  =  pow (2 , height )
128+         node_width  =  2 
129+         in_level  =  1 
130+         root  =  self .fit (X_train ,y_train )
131+         from  collections  import  deque 
132+         q  =  deque ()
133+         q .append (root )
134+         while  q :
135+             count  =  len (q )
136+             width  =  int (max_width  *  node_width  /  in_level )
137+             in_level  +=  1 
138+             while  count > 0 :
139+                 node  =  q .popleft ()
140+                 if  node .left :
141+                     q .append (node .left  )
142+                 if  node .right :
143+                     q .append (node .right )
144+                 node_str  =  (str (node .value ) if  node  else  '' ).center (width )
145+                 print (node_str , end = ' ' )
146+                 count  -=  1  
147+             print ("\n " )
148+ kd  =  KDTree ()
149+ kd .fit ( X_train , y_train )
150+ print (kd .score (X_test ,y_test ))
0 commit comments