Skip to content

Commit 17fac54

Browse files
antmarakisnorvig
authored andcommitted
Learning: Grade Learner (aimacode#496)
* Add grade_learner * Update test_learning.py
1 parent fb503e6 commit 17fac54

File tree

2 files changed

+31
-27
lines changed

2 files changed

+31
-27
lines changed

learning.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -908,6 +908,16 @@ def score(learner, size):
908908
return [(size, mean([score(learner, size) for t in range(trials)]))
909909
for size in sizes]
910910

911+
912+
def grade_learner(predict, tests):
913+
"""Grades the given learner based on how many tests it passes.
914+
tests is a list with each element in the form: (values, output)."""
915+
correct = 0
916+
for t in tests:
917+
if predict(t[0]) == t[1]:
918+
correct += 1
919+
return correct
920+
911921
# ______________________________________________________________________________
912922
# The rest of this file gives datasets for machine learning problems.
913923

tests/test_learning.py

Lines changed: 21 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,19 @@
11
from learning import parse_csv, weighted_mode, weighted_replicate, DataSet, \
22
PluralityLearner, NaiveBayesLearner, NearestNeighborLearner, \
33
NeuralNetLearner, PerceptronLearner, DecisionTreeLearner, \
4-
euclidean_distance
4+
euclidean_distance, grade_learner
55
from utils import DataFile
66

77

88

99
def test_euclidean():
10-
distance = euclidean_distance([1,2], [3,4])
10+
distance = euclidean_distance([1, 2], [3, 4])
1111
assert round(distance, 2) == 2.83
1212

13-
distance = euclidean_distance([1,2,3], [4,5,6])
13+
distance = euclidean_distance([1, 2, 3], [4, 5, 6])
1414
assert round(distance, 2) == 5.2
1515

16-
distance = euclidean_distance([0,0,0], [0,0,0])
16+
distance = euclidean_distance([0, 0, 0], [0, 0, 0])
1717
assert distance == 0
1818

1919

@@ -24,7 +24,7 @@ def test_exclude():
2424

2525
def test_parse_csv():
2626
Iris = DataFile('iris.csv').read()
27-
assert parse_csv(Iris)[0] == [5.1,3.5,1.4,0.2,'setosa']
27+
assert parse_csv(Iris)[0] == [5.1, 3.5, 1.4, 0.2,'setosa']
2828

2929

3030
def test_weighted_mode():
@@ -47,25 +47,25 @@ def test_naive_bayes():
4747

4848
# Discrete
4949
nBD = NaiveBayesLearner(iris)
50-
assert nBD([5,3,1,0.1]) == "setosa"
50+
assert nBD([5, 3, 1, 0.1]) == "setosa"
5151

5252

5353
def test_k_nearest_neighbors():
5454
iris = DataSet(name="iris")
5555

5656
kNN = NearestNeighborLearner(iris,k=3)
57-
assert kNN([5,3,1,0.1]) == "setosa"
58-
assert kNN([6,5,3,1.5]) == "versicolor"
59-
assert kNN([7.5,4,6,2]) == "virginica"
57+
assert kNN([5, 3, 1, 0.1]) == "setosa"
58+
assert kNN([6, 5, 3, 1.5]) == "versicolor"
59+
assert kNN([7.5, 4, 6, 2]) == "virginica"
6060

6161

6262
def test_decision_tree_learner():
6363
iris = DataSet(name="iris")
6464

6565
dTL = DecisionTreeLearner(iris)
66-
assert dTL([5,3,1,0.1]) == "setosa"
67-
assert dTL([6,5,3,1.5]) == "versicolor"
68-
assert dTL([7.5,4,6,2]) == "virginica"
66+
assert dTL([5, 3, 1, 0.1]) == "setosa"
67+
assert dTL([6, 5, 3, 1.5]) == "versicolor"
68+
assert dTL([7.5, 4, 6, 2]) == "virginica"
6969

7070

7171
def test_neural_network_learner():
@@ -75,14 +75,11 @@ def test_neural_network_learner():
7575
iris.classes_to_numbers(classes)
7676

7777
nNL = NeuralNetLearner(iris, [5], 0.15, 75)
78-
pred1 = nNL([5,3,1,0.1])
79-
pred2 = nNL([6,3,3,1.5])
80-
pred3 = nNL([7.5,4,6,2])
78+
tests = [([5, 3, 1, 0.1], 0),
79+
([6, 3, 3, 1.5], 1),
80+
([7.5, 4, 6, 2], 2)]
8181

82-
# NeuralNetLearner might be wrong. If it is, check if prediction is in range.
83-
assert pred1 == 0 or pred1 in range(len(classes))
84-
assert pred2 == 1 or pred2 in range(len(classes))
85-
assert pred3 == 2 or pred3 in range(len(classes))
82+
assert grade_learner(nNL, tests) >= 2
8683

8784

8885
def test_perceptron():
@@ -92,11 +89,8 @@ def test_perceptron():
9289
classes_number = len(iris.values[iris.target])
9390

9491
perceptron = PerceptronLearner(iris)
95-
pred1 = perceptron([5,3,1,0.1])
96-
pred2 = perceptron([6,3,4,1])
97-
pred3 = perceptron([7.5,4,6,2])
98-
99-
# PerceptronLearner might be wrong. If it is, check if prediction is in range.
100-
assert pred1 == 0 or pred1 in range(classes_number)
101-
assert pred2 == 1 or pred2 in range(classes_number)
102-
assert pred3 == 2 or pred3 in range(classes_number)
92+
tests = [([5, 3, 1, 0.1], 0),
93+
([6, 3, 4, 1.1], 1),
94+
([7.5, 4, 6, 2], 2)]
95+
96+
assert grade_learner(perceptron, tests) >= 2

0 commit comments

Comments
 (0)