Skip to content

Commit 022b81a

Browse files
Added standard deviation method to StatisticsReporter.
Updated tests to match current implementation.
1 parent fa939e3 commit 022b81a

File tree

4 files changed

+27
-18
lines changed

4 files changed

+27
-18
lines changed

neat/statistics.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import copy
33
import csv
44

5-
from neat.math_util import mean
5+
from neat.math_util import mean, stdev
66
from neat.six_util import iteritems
77
from neat.reporting import BaseReporter
88

@@ -26,16 +26,23 @@ def post_evaluate(self, config, population, species, best_genome):
2626
self.generation_statistics.append(species_stats)
2727
self.generation_cross_validation_statistics.append(species_cross_validation_stats)
2828

29-
def get_average_fitness(self):
30-
"""Get the per-generation average fitness."""
31-
avg_fitness = []
29+
def get_fitness_stat(self, f):
30+
stat = []
3231
for stats in self.generation_statistics:
3332
scores = []
3433
for species_stats in stats.values():
3534
scores.extend(species_stats.values())
36-
avg_fitness.append(mean(scores))
35+
stat.append(f(scores))
36+
37+
return stat
38+
39+
def get_fitness_mean(self):
40+
"""Get the per-generation average fitness."""
41+
return self.get_fitness_stat(mean)
3742

38-
return avg_fitness
43+
def get_fitness_stdev(self):
44+
"""Get the per-generation standard deviation of the fitness."""
45+
return self.get_fitness_stat(stdev)
3946

4047
def get_average_cross_validation_fitness(self):
4148
"""Get the per-generation average cross_validation fitness."""

tests/test_activation.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ def assert_almost_equal(a, b):
1313
max_abs = max(abs(a), abs(b))
1414
abs_rel_err = abs(a - b) / max_abs
1515
if abs_rel_err > 1e-6:
16-
raise NotAlmostEqualException()
16+
raise NotAlmostEqualException("{0:.4f} !~= {1:.4f}".format(a, b))
1717

1818

1919
def test_sigmoid():
@@ -29,7 +29,7 @@ def test_sin():
2929

3030

3131
def test_gauss():
32-
assert_almost_equal(activations.gauss_activation(0.0), 0.398942280401)
32+
assert_almost_equal(activations.gauss_activation(0.0), 1.0)
3333
assert_almost_equal(activations.gauss_activation(-1.0),
3434
activations.gauss_activation(1.0))
3535

@@ -131,6 +131,7 @@ def test_function_set():
131131

132132
assert not s.is_valid('foo')
133133

134+
134135
if __name__ == '__main__':
135136
test_sigmoid()
136137
test_tanh()
@@ -146,3 +147,4 @@ def test_function_set():
146147
test_hat()
147148
test_square()
148149
test_cube()
150+

tests/test_feedforward_network.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,15 +31,15 @@ def test_basic():
3131

3232
assert r.values[0] == 0.0
3333

34-
result = r.activate([1.0])
34+
result = r.activate([0.2])
3535

36-
assert r.values[-1] == 1.0
36+
assert r.values[-1] == 0.2
3737
assert_almost_equal(r.values[0], 0.731, 0.001)
3838
assert result[0] == r.values[0]
3939

40-
result = r.activate([2.0])
40+
result = r.activate([0.4])
4141

42-
assert r.values[-1] == 2.0
42+
assert r.values[-1] == 0.4
4343
assert_almost_equal(r.values[0], 0.881, 0.001)
4444
assert result[0] == r.values[0]
4545

@@ -54,7 +54,7 @@ def test_basic():
5454
# 'compatibility_threshold':3.0,
5555
# 'excess_coefficient':1.0,
5656
# 'disjoint_coefficient':1.0,
57-
# 'weight_coefficient':1.0,
57+
# 'compatibility_weight_coefficient':1.0,
5858
# 'conn_add_prob':0.5,
5959
# 'conn_delete_prob':0.05,
6060
# 'node_add_prob':0.1,

tests/test_recurrent_network.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,21 +39,21 @@ def test_basic():
3939
assert len(r.values[0]) == 2
4040
assert len(r.values[1]) == 2
4141

42-
result = r.activate([1.0])
42+
result = r.activate([0.2])
4343

4444
assert r.active == 1
45-
assert r.values[1][-1] == 1.0
45+
assert r.values[1][-1] == 0.2
4646
assert_almost_equal(r.values[1][0], 0.731, 0.001)
4747
assert result[0] == r.values[1][0]
4848

49-
result = r.activate([2.0])
49+
result = r.activate([0.4])
5050

5151
assert r.active == 0
52-
assert r.values[0][-1] == 2.0
52+
assert r.values[0][-1] == 0.4
5353
assert_almost_equal(r.values[0][0], 0.881, 0.001)
5454
assert result[0] == r.values[0][0]
5555

5656

5757
if __name__ == '__main__':
58+
test_unconnected()
5859
test_basic()
59-
test_unconnected()

0 commit comments

Comments
 (0)