1
+ import numpy as np
2
+ from collections import Counter
3
+ from sklearn import datasets
4
+ from sklearn .model_selection import train_test_split
5
+
6
# Load the Iris dataset that ships with scikit-learn.
data = datasets.load_iris()

# Feature rows and their matching labels, both as NumPy arrays.
X = np.array(data["data"])
y = np.array(data["target"])
# Class names; classifier() indexes this sequence with a label from the targets.
classes = data["target_names"]

# Split into train/test subsets. No random_state is passed, so the split is
# presumably different on every run (see the train_test_split documentation).
X_train, X_test, y_train, y_test = train_test_split(X, y)
13
+
14
def euclidean_distance(a, b):
    """
    Return the Euclidean distance between two points.

    Params:
    :a: Coordinates of the first point (any array-like of numbers)
    :b: Coordinates of the second point, same shape as ``a``

    The result is converted to a built-in ``float`` so that its repr is a
    plain number (e.g. ``5.0``) regardless of how the installed NumPy
    version prints its own scalar types.

    >>> euclidean_distance([0, 0], [3, 4])
    5.0
    >>> euclidean_distance([1, 2, 3], [1, 8, 11])
    10.0
    """
    # asarray avoids copying inputs that are already NumPy arrays.
    return float(np.linalg.norm(np.asarray(a) - np.asarray(b)))
23
+
24
def classifier(train_data, train_target, classes, point, k=5):
    """
    Classifies the point using the KNN algorithm.

    The k closest training points are found (ranked in ascending order of
    euclidean distance) and the label that occurs most often among them
    decides the class.

    Params:
    :train_data: Set of points that are classified into two or more classes
    :train_target: List of classes in the order of train_data points
    :classes: Labels of the classes, indexed by the values in train_target
    :point: The data point that needs to be classified
    :k: Number of nearest neighbours that vote (must be at least 1). If k
        exceeds the number of training points, every training point votes.

    Returns the entry of ``classes`` selected by the winning label.

    Raises ValueError if k is smaller than 1, if there is no training data,
    or if train_data and train_target differ in length.

    >>> X_train = [[0, 0], [1, 0], [0, 1], [0.5, 0.5], [3, 3], [2, 3], [3, 2]]
    >>> y_train = [0, 0, 0, 0, 1, 1, 1]
    >>> classes = ['A','B']; point = [1.2,1.2]
    >>> classifier(X_train, y_train, classes,point)
    'A'
    """
    # A non-positive k would either leave no voters (k == 0) or, through
    # negative slicing, silently drop the farthest points instead of
    # selecting the nearest ones.
    if k < 1:
        raise ValueError(f"k must be at least 1, got {k}")

    # Materialise both inputs so their lengths can be compared; zip() alone
    # would silently truncate to the shorter one.
    samples = list(train_data)
    labels = list(train_target)
    if len(samples) != len(labels):
        raise ValueError(
            "train_data and train_target must have the same length, "
            f"got {len(samples)} and {len(labels)}"
        )
    if not samples:
        raise ValueError("train_data must contain at least one point")

    # Distance of every training point from the point to be classified,
    # paired with that training point's label.
    distances = [
        (euclidean_distance(sample, point), label)
        for sample, label in zip(samples, labels)
    ]
    # Choosing 'k' points with the least distances. Sorting the tuples means
    # equal distances are ordered by label value.
    votes = [label for _, label in sorted(distances)[:k]]
    # Most commonly occurring class among them
    # is the class into which the point is classified
    result = Counter(votes).most_common(1)[0][0]
    return classes[result]
52
+
53
+
54
if __name__ == "__main__":
    # Classify one hard-coded four-feature sample against the training split
    # and print the resulting class name.
    sample = [4.4, 3.1, 1.3, 1.4]
    print(classifier(X_train, y_train, classes, sample))
0 commit comments