Skip to content

Commit f4d99fe

Browse files
antmarakisnorvig
authored andcommitted
Update test_learning.py (aimacode#534)
1 parent d57231c commit f4d99fe

File tree

1 file changed

+35
-33
lines changed

1 file changed

+35
-33
lines changed

tests/test_learning.py

Lines changed: 35 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
1-
21
import pytest
32
import math
3+
import random
44
from utils import DataFile
55
from learning import *
66

77

8+
random.seed("aima-python")
9+
10+
811
def test_euclidean():
912
distance = euclidean_distance([1, 2], [3, 4])
1013
assert round(distance, 2) == 2.83
@@ -15,6 +18,34 @@ def test_euclidean():
1518
distance = euclidean_distance([0, 0, 0], [0, 0, 0])
1619
assert distance == 0
1720

21+
def test_rms_error():
22+
assert rms_error([2, 2], [2, 2]) == 0
23+
assert rms_error((0, 0), (0, 1)) == math.sqrt(0.5)
24+
assert rms_error((1, 0), (0, 1)) == 1
25+
assert rms_error((0, 0), (0, -1)) == math.sqrt(0.5)
26+
assert rms_error((0, 0.5), (0, -0.5)) == math.sqrt(0.5)
27+
28+
def test_manhattan_distance():
29+
assert manhattan_distance([2, 2], [2, 2]) == 0
30+
assert manhattan_distance([0, 0], [0, 1]) == 1
31+
assert manhattan_distance([1, 0], [0, 1]) == 2
32+
assert manhattan_distance([0, 0], [0, -1]) == 1
33+
assert manhattan_distance([0, 0.5], [0, -0.5]) == 1
34+
35+
def test_mean_boolean_error():
36+
assert mean_boolean_error([1, 1], [0, 0]) == 1
37+
assert mean_boolean_error([0, 1], [1, 0]) == 1
38+
assert mean_boolean_error([1, 1], [0, 1]) == 0.5
39+
assert mean_boolean_error([0, 0], [0, 0]) == 0
40+
assert mean_boolean_error([1, 1], [1, 1]) == 0
41+
42+
def test_mean_error():
43+
assert mean_error([2, 2], [2, 2]) == 0
44+
assert mean_error([0, 0], [0, 1]) == 0.5
45+
assert mean_error([1, 0], [0, 1]) == 1
46+
assert mean_error([0, 0], [0, -1]) == 0.5
47+
assert mean_error([0, 0.5], [0, -0.5]) == 0.5
48+
1849

1950
def test_exclude():
2051
iris = DataSet(name='iris', exclude=[3])
@@ -23,7 +54,7 @@ def test_exclude():
2354

2455
def test_parse_csv():
2556
Iris = DataFile('iris.csv').read()
26-
assert parse_csv(Iris)[0] == [5.1, 3.5, 1.4, 0.2,'setosa']
57+
assert parse_csv(Iris)[0] == [5.1, 3.5, 1.4, 0.2, 'setosa']
2758

2859

2960
def test_weighted_mode():
@@ -74,39 +105,11 @@ def test_naive_bayes():
74105
def test_k_nearest_neighbors():
75106
iris = DataSet(name="iris")
76107
kNN = NearestNeighborLearner(iris,k=3)
77-
assert kNN([5,3,1,0.1]) == "setosa"
108+
assert kNN([5, 3, 1, 0.1]) == "setosa"
78109
assert kNN([5, 3, 1, 0.1]) == "setosa"
79110
assert kNN([6, 5, 3, 1.5]) == "versicolor"
80111
assert kNN([7.5, 4, 6, 2]) == "virginica"
81112

82-
def test_rms_error():
83-
assert rms_error([2,2], [2,2]) == 0
84-
assert rms_error((0,0), (0,1)) == math.sqrt(0.5)
85-
assert rms_error((1,0), (0,1)) == 1
86-
assert rms_error((0,0), (0,-1)) == math.sqrt(0.5)
87-
assert rms_error((0,0.5), (0,-0.5)) == math.sqrt(0.5)
88-
89-
def test_manhattan_distance():
90-
assert manhattan_distance([2,2], [2,2]) == 0
91-
assert manhattan_distance([0,0], [0,1]) == 1
92-
assert manhattan_distance([1,0], [0,1]) == 2
93-
assert manhattan_distance([0,0], [0,-1]) == 1
94-
assert manhattan_distance([0,0.5], [0,-0.5]) == 1
95-
96-
def test_mean_boolean_error():
97-
assert mean_boolean_error([1,1], [0,0]) == 1
98-
assert mean_boolean_error([0,1], [1,0]) == 1
99-
assert mean_boolean_error([1,1], [0,1]) == 0.5
100-
assert mean_boolean_error([0,0], [0,0]) == 0
101-
assert mean_boolean_error([1,1], [1,1]) == 0
102-
103-
def test_mean_error():
104-
assert mean_error([2,2], [2,2]) == 0
105-
assert mean_error([0,0], [0,1]) == 0.5
106-
assert mean_error([1,0], [0,1]) == 1
107-
assert mean_error([0,0], [0,-1]) == 0.5
108-
assert mean_error([0,0.5], [0,-0.5]) == 0.5
109-
110113

111114
def test_decision_tree_learner():
112115
iris = DataSet(name="iris")
@@ -118,7 +121,7 @@ def test_decision_tree_learner():
118121

119122
def test_neural_network_learner():
120123
iris = DataSet(name="iris")
121-
classes = ["setosa","versicolor","virginica"]
124+
classes = ["setosa", "versicolor", "virginica"]
122125
iris.classes_to_numbers(classes)
123126
nNL = NeuralNetLearner(iris, [5], 0.15, 75)
124127
tests = [([5, 3, 1, 0.1], 0),
@@ -154,4 +157,3 @@ def test_random_weights():
154157
assert len(test_weights) == num_weights
155158
for weight in test_weights:
156159
assert weight >= min_value and weight <= max_value
157-

0 commit comments

Comments
 (0)