Commit f8b1d0d

Add knn
1 parent bffbac1 commit f8b1d0d

2 files changed: +186 -0 lines

k近邻法/brute_knn.py

Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
from sklearn import datasets
import numpy as np
from sklearn.model_selection import train_test_split

## Example 1: iris for classification (3 classes)
# X, y = datasets.load_iris(return_X_y=True)
# Example 2: breast cancer (binary classification)
X, y = datasets.load_breast_cancer(return_X_y=True)

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# my k-NN (brute force: compute all pairwise distances)
class KNN():
    def __init__(self, data, K=3):
        self.X = data[0]
        self.y = data[1]
        self.K = K

    def fit(self, X_test):
        # pairwise differences between every test point and every training point
        diffX = np.repeat(X_test, len(self.X), axis=0) - np.tile(self.X, (len(X_test), 1))
        # squared Euclidean distances, one row per test point
        square_diffX = (diffX**2).sum(axis=1).reshape(len(X_test), len(self.X))
        sorted_index = square_diffX.argsort()
        predict = [0 for _ in range(len(X_test))]
        for i in range(len(X_test)):
            # majority vote among the K nearest training labels
            class_count = {}
            for j in range(self.K):
                vote_label = self.y[sorted_index[i][j]]
                class_count[vote_label] = class_count.get(vote_label, 0) + 1
            sorted_count = sorted(class_count.items(), key=lambda x: x[1], reverse=True)
            predict[i] = sorted_count[0][0]
        return predict

    def predict(self, X_test):
        return self.fit(X_test)

    def score(self, X, y):
        y_pred = self.predict(X)
        return 1 - np.count_nonzero(y - y_pred) / len(y)


knn = KNN((X_train, y_train), K=3)
print(knn.score(X_test, y_test))
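
As a quick cross-check (not part of this commit), the same split can be scored with scikit-learn's brute-force KNeighborsClassifier; up to tie-breaking among equidistant neighbours, its accuracy should match the hand-rolled KNN above. A minimal sketch, reusing the X_train/X_test split already defined in the file:

from sklearn.neighbors import KNeighborsClassifier

# reference brute-force 3-NN on the identical train/test split
clf = KNeighborsClassifier(n_neighbors=3, algorithm='brute')
clf.fit(X_train, y_train)
print(clf.score(X_test, y_test))  # compare with knn.score(X_test, y_test)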

k近邻法/kdtree_knn.py

Lines changed: 150 additions & 0 deletions
@@ -0,0 +1,150 @@
import math
from collections import deque

from sklearn import datasets
import numpy as np
from sklearn.model_selection import train_test_split

## Example 1: iris for classification (3 classes)
# X, y = datasets.load_iris(return_X_y=True)
# Example 2: breast cancer (binary classification)
X, y = datasets.load_breast_cancer(return_X_y=True)

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# my k-NN
# kd-tree
class KDNode:
    '''
    value: the training sample [X, y] stored at this node
    '''
    def __init__(self, value=None, parent=None, left=None, right=None, index=None):
        self.value = value
        self.parent = parent
        self.left = left
        self.right = right

    @property
    def brother(self):
        # the sibling of this node under the same parent (None at the root)
        if not self.parent:
            bro = None
        else:
            if self.parent.left is self:
                bro = self.parent.right
            else:
                bro = self.parent.left
        return bro


class KDTree():
    def __init__(self, K=3):
        self.root = KDNode()
        self.K = K

    def _build(self, data, axis=0, parent=None):
        '''
        data: rows of [X, y]; split on the median along the current axis
        '''
        # choose median point
        if len(data) == 0:
            root = KDNode()
            return root
        data = np.array(sorted(data, key=lambda x: x[axis]))
        median = int(len(data) / 2)
        loc = data[median]
        root = KDNode(loc, parent=parent)
        # cycle through the feature axes (the last column is the label)
        new_axis = (axis + 1) % (len(data[0]) - 1)
        if len(data[:median, :]) == 0:
            root.left = None
        else:
            root.left = self._build(data[:median, :], axis=new_axis, parent=root)
        if len(data[median+1:, :]) == 0:
            root.right = None
        else:
            root.right = self._build(data[median+1:, :], axis=new_axis, parent=root)
        self.root = root
        return root

    def fit(self, X, y):
        # concat X, y so each node stores its label in the last column
        data = np.concatenate([X, y.reshape(-1, 1)], axis=1)
        root = self._build(data)
        return root

    def _get_eu_distance(self, arr1: np.ndarray, arr2: np.ndarray) -> float:
        return ((arr1 - arr2) ** 2).sum() ** 0.5

    def _search_node(self, current, point, result, class_count):
        # Get max_node, max_distance.
        if not result:
            max_node = None
            max_distance = float('inf')
        else:
            # the farthest (node, distance) pair among the current candidates
            max_node, max_distance = sorted(result.items(), key=lambda n: n[1], reverse=True)[0]
        node_dist = self._get_eu_distance(current.value[:-1], point)
        if len(result) == self.K:
            # replace the farthest candidate if the current node is closer
            if node_dist < max_distance:
                result.pop(max_node)
                result[current] = node_dist
                class_count[current.value[-1]] = class_count.get(current.value[-1], 0) + 1
                class_count[max_node.value[-1]] = class_count.get(max_node.value[-1], 1) - 1
        elif len(result) < self.K:
            result[current] = node_dist
            class_count[current.value[-1]] = class_count.get(current.value[-1], 0) + 1
        return result, class_count

    def search(self, point):
        # find the leaf node (current) whose cell the query point falls into
        current = self.root
        axis = 0
        while current:
            if point[axis] < current.value[axis]:
                prev = current
                current = current.left
            else:
                prev = current
                current = current.right
            axis = (axis + 1) % len(point)
        current = prev
        # collect k near points by walking back up to the root,
        # considering each ancestor and its brother as candidates
        result = {}
        class_count = {}
        while current:
            result, class_count = self._search_node(current, point, result, class_count)
            if current.brother:
                result, class_count = self._search_node(current.brother, point, result, class_count)
            current = current.parent
        return result, class_count

    def predict(self, X_test):
        predict = [0 for _ in range(len(X_test))]
        for i in range(len(X_test)):
            _, class_count = self.search(X_test[i])
            # majority vote among the collected neighbours
            sorted_count = sorted(class_count.items(), key=lambda x: x[1], reverse=True)
            predict[i] = sorted_count[0][0]
        return predict

    def score(self, X, y):
        y_pred = self.predict(X)
        return 1 - np.count_nonzero(y - y_pred) / len(y)

    def print_tree(self, X_train, y_train):
        # level-order print of the tree, roughly centred per level
        height = int(math.log(len(X_train)) / math.log(2))
        max_width = pow(2, height)
        node_width = 2
        in_level = 1
        root = self.fit(X_train, y_train)
        q = deque()
        q.append(root)
        while q:
            count = len(q)
            width = int(max_width * node_width / in_level)
            in_level += 1
            while count > 0:
                node = q.popleft()
                if node.left:
                    q.append(node.left)
                if node.right:
                    q.append(node.right)
                node_str = (str(node.value) if node else '').center(width)
                print(node_str, end=' ')
                count -= 1
            print("\n")


kd = KDTree()
kd.fit(X_train, y_train)
print(kd.score(X_test, y_test))
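
Note that the upward pass in KDTree.search only considers each ancestor and its brother node, without descending into the brother's subtree, so the neighbour set it returns is an approximation of the exact K nearest neighbours. One way to gauge the gap (not part of this commit) is to score an exact kd-tree classifier from scikit-learn on the same split; a minimal sketch:

from sklearn.neighbors import KNeighborsClassifier

# exact 3-NN search with a kd-tree backend, on the identical split
ref = KNeighborsClassifier(n_neighbors=3, algorithm='kd_tree')
ref.fit(X_train, y_train)
print(ref.score(X_test, y_test))  # compare with kd.score(X_test, y_test) above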
