Skip to content

Commit b1a769c

Browse files
parth-paradkarcclauss
authored andcommitted
Add pure implementation of K-Nearest Neighbours (#1278)
* Pure implementation of KNN added * Comments and test case added * doctest added
1 parent 0a7d387 commit b1a769c

File tree

1 file changed

+55
-0
lines changed

1 file changed

+55
-0
lines changed
+55
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import numpy as np
2+
from collections import Counter
3+
from sklearn import datasets
4+
from sklearn.model_selection import train_test_split
5+
6+
# Load the Iris dataset that ships with scikit-learn.
data = datasets.load_iris()

# Feature rows, their class indices, and the readable class names.
X, y = np.array(data["data"]), np.array(data["target"])
classes = data["target_names"]

# Split into training and test portions using train_test_split's defaults
# (no random_state is passed, so the split is not pinned to a seed).
X_train, X_test, y_train, y_test = train_test_split(X, y)
13+
14+
def euclidean_distance(a, b):
    """
    Gives the euclidean distance between two points.

    Params:
    :a: Coordinates of the first point (array-like of numbers)
    :b: Coordinates of the second point (expected to have the same shape as a)

    The result is returned as a built-in float rather than a NumPy scalar so
    that it prints as a plain number (e.g. ``5.0``) no matter how the
    installed NumPy version renders its scalar types.

    >>> euclidean_distance([0, 0], [3, 4])
    5.0
    >>> euclidean_distance([1, 2, 3], [1, 8, 11])
    10.0
    """
    return float(np.linalg.norm(np.asarray(a) - np.asarray(b)))
23+
24+
def classifier(train_data, train_target, classes, point, k=5):
    """
    Classifies the point using the KNN algorithm.

    The k closest training points are found (ranked in ascending order of
    euclidean distance) and the label that occurs most often among them is
    the predicted class.

    Params:
    :train_data: Set of points that are classified into two or more classes
    :train_target: List of classes in the order of train_data points; each
        entry is used as an index into ``classes``
    :classes: Labels of the classes
    :point: The data point that needs to be classified
    :k: Number of nearest neighbours that vote (must be at least 1)

    Returns the entry of ``classes`` selected by the winning vote.
    Raises ValueError if k is smaller than 1 or there is no training data.

    >>> X_train = [[0, 0], [1, 0], [0, 1], [0.5, 0.5], [3, 3], [2, 3], [3, 2]]
    >>> y_train = [0, 0, 0, 0, 1, 1, 1]
    >>> classes = ['A','B']; point = [1.2,1.2]
    >>> classifier(X_train, y_train, classes,point)
    'A'
    >>> classifier(X_train, y_train, classes, point, k=0)
    Traceback (most recent call last):
        ...
    ValueError: k must be a positive integer, got 0
    """
    # A non-positive k would either leave no voters (k == 0) or, through
    # negative slicing, silently drop the farthest points instead of
    # selecting the nearest ones.
    if k < 1:
        raise ValueError(f"k must be a positive integer, got {k}")

    # Convert the query point once instead of once per training sample.
    target = np.asarray(point)

    # (distance, label) pair for every training sample.
    distances = [
        (float(np.linalg.norm(np.asarray(sample) - target)), label)
        for sample, label in zip(train_data, train_target)
    ]
    if not distances:
        raise ValueError("train_data and train_target must not be empty")

    # Choosing 'k' points with the least distances; pairs with equal
    # distance are ordered by their label because tuples sort element-wise.
    votes = [label for _, label in sorted(distances)[:k]]

    # Most commonly occurring class among them
    # is the class into which the point is classified
    result = Counter(votes).most_common(1)[0][0]
    return classes[result]
52+
53+
54+
if __name__ == "__main__":
    # Classify one hard-coded four-feature sample against the training split
    # and print the predicted class name.
    print(
        classifier(
            X_train,
            y_train,
            classes,
            point=[4.4, 3.1, 1.3, 1.4],
        )
    )

0 commit comments

Comments
 (0)