11from learning import parse_csv , weighted_mode , weighted_replicate , DataSet , \
22 PluralityLearner , NaiveBayesLearner , NearestNeighborLearner , \
33 NeuralNetLearner , PerceptronLearner , DecisionTreeLearner , \
4- euclidean_distance
4+ euclidean_distance , grade_learner
55from utils import DataFile
66
77
88
99def test_euclidean ():
10- distance = euclidean_distance ([1 ,2 ], [3 ,4 ])
10+ distance = euclidean_distance ([1 , 2 ], [3 , 4 ])
1111 assert round (distance , 2 ) == 2.83
1212
13- distance = euclidean_distance ([1 ,2 , 3 ], [4 ,5 , 6 ])
13+ distance = euclidean_distance ([1 , 2 , 3 ], [4 , 5 , 6 ])
1414 assert round (distance , 2 ) == 5.2
1515
16- distance = euclidean_distance ([0 ,0 , 0 ], [0 ,0 , 0 ])
16+ distance = euclidean_distance ([0 , 0 , 0 ], [0 , 0 , 0 ])
1717 assert distance == 0
1818
1919
@@ -24,7 +24,7 @@ def test_exclude():
2424
2525def test_parse_csv ():
2626 Iris = DataFile ('iris.csv' ).read ()
27- assert parse_csv (Iris )[0 ] == [5.1 ,3.5 ,1.4 ,0.2 ,'setosa' ]
27+ assert parse_csv (Iris )[0 ] == [5.1 , 3.5 , 1.4 , 0.2 ,'setosa' ]
2828
2929
3030def test_weighted_mode ():
@@ -47,25 +47,25 @@ def test_naive_bayes():
4747
4848 # Discrete
4949 nBD = NaiveBayesLearner (iris )
50- assert nBD ([5 ,3 , 1 , 0.1 ]) == "setosa"
50+ assert nBD ([5 , 3 , 1 , 0.1 ]) == "setosa"
5151
5252
5353def test_k_nearest_neighbors ():
5454 iris = DataSet (name = "iris" )
5555
5656 kNN = NearestNeighborLearner (iris ,k = 3 )
57- assert kNN ([5 ,3 , 1 , 0.1 ]) == "setosa"
58- assert kNN ([6 ,5 , 3 , 1.5 ]) == "versicolor"
59- assert kNN ([7.5 ,4 , 6 , 2 ]) == "virginica"
57+ assert kNN ([5 , 3 , 1 , 0.1 ]) == "setosa"
58+ assert kNN ([6 , 5 , 3 , 1.5 ]) == "versicolor"
59+ assert kNN ([7.5 , 4 , 6 , 2 ]) == "virginica"
6060
6161
6262def test_decision_tree_learner ():
6363 iris = DataSet (name = "iris" )
6464
6565 dTL = DecisionTreeLearner (iris )
66- assert dTL ([5 ,3 , 1 , 0.1 ]) == "setosa"
67- assert dTL ([6 ,5 , 3 , 1.5 ]) == "versicolor"
68- assert dTL ([7.5 ,4 , 6 , 2 ]) == "virginica"
66+ assert dTL ([5 , 3 , 1 , 0.1 ]) == "setosa"
67+ assert dTL ([6 , 5 , 3 , 1.5 ]) == "versicolor"
68+ assert dTL ([7.5 , 4 , 6 , 2 ]) == "virginica"
6969
7070
7171def test_neural_network_learner ():
@@ -75,14 +75,11 @@ def test_neural_network_learner():
7575 iris .classes_to_numbers (classes )
7676
7777 nNL = NeuralNetLearner (iris , [5 ], 0.15 , 75 )
78- pred1 = nNL ([5 ,3 , 1 , 0.1 ])
79- pred2 = nNL ([6 ,3 , 3 , 1.5 ])
80- pred3 = nNL ([7.5 ,4 , 6 , 2 ])
78+ tests = [ ([5 , 3 , 1 , 0.1 ], 0 ),
79+ ([6 , 3 , 3 , 1.5 ], 1 ),
80+ ([7.5 , 4 , 6 , 2 ], 2 )]
8181
82- # NeuralNetLearner might be wrong. If it is, check if prediction is in range.
83- assert pred1 == 0 or pred1 in range (len (classes ))
84- assert pred2 == 1 or pred2 in range (len (classes ))
85- assert pred3 == 2 or pred3 in range (len (classes ))
82+ assert grade_learner (nNL , tests ) >= 2
8683
8784
8885def test_perceptron ():
@@ -92,11 +89,8 @@ def test_perceptron():
9289 classes_number = len (iris .values [iris .target ])
9390
9491 perceptron = PerceptronLearner (iris )
95- pred1 = perceptron ([5 ,3 ,1 ,0.1 ])
96- pred2 = perceptron ([6 ,3 ,4 ,1 ])
97- pred3 = perceptron ([7.5 ,4 ,6 ,2 ])
98-
99- # PerceptronLearner might be wrong. If it is, check if prediction is in range.
100- assert pred1 == 0 or pred1 in range (classes_number )
101- assert pred2 == 1 or pred2 in range (classes_number )
102- assert pred3 == 2 or pred3 in range (classes_number )
92+ tests = [([5 , 3 , 1 , 0.1 ], 0 ),
93+ ([6 , 3 , 4 , 1.1 ], 1 ),
94+ ([7.5 , 4 , 6 , 2 ], 2 )]
95+
96+ assert grade_learner (perceptron , tests ) >= 2
0 commit comments