Skip to content

Commit 3753863

Browse files
Removed unused code.
Added input check to RecurrentNetwork.activate. Made FeedForwardNetwork more consistent with RecurrentNetwork. Added tests for FeedForwardNetwork. Added to RecurrentNetwork tests.
1 parent a9aeb21 commit 3753863

File tree

3 files changed

+63
-25
lines changed

3 files changed

+63
-25
lines changed

neat/nn/__init__.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,20 @@
33

44

55
class FeedForwardNetwork(object):
6-
def __init__(self, max_node, inputs, outputs, node_evals):
7-
self.node_evals = node_evals
6+
def __init__(self, inputs, outputs, node_evals):
87
self.input_nodes = inputs
98
self.output_nodes = outputs
9+
self.node_evals = node_evals
1010
self.values = dict((key, 0.0) for key in inputs + outputs)
1111

12-
def serial_activate(self, inputs):
12+
def activate(self, inputs):
1313
if len(self.input_nodes) != len(inputs):
1414
raise Exception("Expected {0} inputs, got {1}".format(len(self.input_nodes), len(inputs)))
1515

1616
for k, v in zip(self.input_nodes, inputs):
1717
self.values[k] = v
1818

19-
for node, agg_func, act_func, bias, response, links in self.node_evals:
19+
for node, act_func, agg_func, bias, response, links in self.node_evals:
2020
#print(node, func, bias, response, links)
2121
node_inputs = []
2222
for i, w in links:
@@ -38,7 +38,6 @@ def create_feed_forward_phenotype(genome, config):
3838
layers = feed_forward_layers(config.genome_config.input_keys, config.genome_config.output_keys, connections)
3939
#print(layers)
4040
node_evals = []
41-
max_used_node = max(max(config.genome_config.input_keys), max(config.genome_config.output_keys))
4241
for layer in layers:
4342
for node in layer:
4443
inputs = []
@@ -48,17 +47,15 @@ def create_feed_forward_phenotype(genome, config):
4847
if cg.output == node and cg.enabled:
4948
inputs.append((cg.input, cg.weight))
5049
node_expr.append("v[%d] * %f" % (cg.input, cg.weight))
51-
max_used_node = max(max_used_node, cg.input)
5250

53-
max_used_node = max(max_used_node, node)
5451
ng = genome.nodes[node]
5552
aggregation_function = config.genome_config.aggregation_function_defs[ng.aggregation]
5653
activation_function = config.genome_config.activation_defs.get(ng.activation)
57-
node_evals.append((node, aggregation_function, activation_function, ng.bias, ng.response, inputs))
54+
node_evals.append((node, activation_function, aggregation_function, ng.bias, ng.response, inputs))
5855

5956
#print(" v[%d] = %s(%f + %f * %s(%s))" % (node, ng.activation, ng.bias, ng.response, ng.aggregation, ", ".join(node_expr)))
6057

61-
return FeedForwardNetwork(max_used_node, config.genome_config.input_keys, config.genome_config.output_keys, node_evals)
58+
return FeedForwardNetwork(config.genome_config.input_keys, config.genome_config.output_keys, node_evals)
6259

6360

6461
class RecurrentNetwork(object):
@@ -83,6 +80,9 @@ def reset(self):
8380
self.active = 0
8481

8582
def activate(self, inputs):
83+
if len(self.input_nodes) != len(inputs):
84+
raise Exception("Expected {0} inputs, got {1}".format(len(self.input_nodes), len(inputs)))
85+
8686
ivalues = self.values[self.active]
8787
ovalues = self.values[1 - self.active]
8888
self.active = 1 - self.active

tests/test_feedforward.py renamed to tests/test_feedforward_network.py

Lines changed: 42 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,49 @@
1-
import random
2-
from neat import nn
1+
from neat import activations
2+
from neat.nn import FeedForwardNetwork
33

44

55
def assert_almost_equal(x, y, tol):
66
assert abs(x - y) < tol, "{!r} !~= {!r}".format(x, y)
77

88

9+
def test_unconnected():
10+
# Unconnected network with no inputs and one output neuron.
11+
node_evals = [(0, activations.sigmoid_activation, sum, 0.0, 1.0, [])]
12+
r = FeedForwardNetwork([], [0], node_evals)
13+
14+
assert r.values[0] == 0.0
15+
16+
result = r.activate([])
17+
18+
assert_almost_equal(r.values[0], 0.5, 0.001)
19+
assert result[0] == r.values[0]
20+
21+
result = r.activate([])
22+
23+
assert_almost_equal(r.values[0], 0.5, 0.001)
24+
assert result[0] == r.values[0]
25+
26+
27+
def test_basic():
28+
# Very simple network with one connection of weight one to a single sigmoid output node.
29+
node_evals = [(0, activations.sigmoid_activation, sum, 0.0, 1.0, [(-1, 1.0)])]
30+
r = FeedForwardNetwork([-1], [0], node_evals)
31+
32+
assert r.values[0] == 0.0
33+
34+
result = r.activate([1.0])
35+
36+
assert r.values[-1] == 1.0
37+
assert_almost_equal(r.values[0], 0.731, 0.001)
38+
assert result[0] == r.values[0]
39+
40+
result = r.activate([2.0])
41+
42+
assert r.values[-1] == 2.0
43+
assert_almost_equal(r.values[0], 0.881, 0.001)
44+
assert result[0] == r.values[0]
45+
46+
947
# TODO: Update this test for the current implementation.
1048
# def test_simple_nohidden():
1149
# config_params = {
@@ -72,10 +110,5 @@ def assert_almost_equal(x, y, tol):
72110

73111

74112
if __name__ == '__main__':
75-
test_creates_cycle()
76-
test_required_for_output()
77-
test_fuzz_required()
78-
test_feed_forward_layers()
79-
test_fuzz_feed_forward_layers()
80-
#test_simple_nohidden()
81-
#test_simple_hidden()
113+
test_unconnected()
114+
test_basic()

tests/test_recurrent_network.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from neat import activations, nn
1+
from neat import activations
2+
from neat.nn import RecurrentNetwork
23

34

45
def assert_almost_equal(x, y, tol):
@@ -8,45 +9,49 @@ def assert_almost_equal(x, y, tol):
89
def test_unconnected():
910
# Unconnected network with no inputs and one output neuron.
1011
node_evals = [(0, activations.sigmoid_activation, sum, 0.0, 1.0, [])]
11-
r = nn.RecurrentNetwork([], [0], node_evals)
12+
r = RecurrentNetwork([], [0], node_evals)
1213

1314
assert r.active == 0
1415
assert len(r.values) == 2
1516
assert len(r.values[0]) == 1
1617
assert len(r.values[1]) == 1
1718

18-
r.activate([1.0])
19+
result = r.activate([])
1920

2021
assert r.active == 1
2122
assert_almost_equal(r.values[1][0], 0.5, 0.001)
23+
assert result[0] == r.values[1][0]
2224

23-
r.activate([2.0])
25+
result = r.activate([])
2426

2527
assert r.active == 0
2628
assert_almost_equal(r.values[0][0], 0.5, 0.001)
29+
assert result[0] == r.values[0][0]
2730

2831

2932
def test_basic():
3033
# Very simple network with one connection of weight one to a single sigmoid output node.
3134
node_evals = [(0, activations.sigmoid_activation, sum, 0.0, 1.0, [(-1, 1.0)])]
32-
r = nn.RecurrentNetwork([-1], [0], node_evals)
35+
r = RecurrentNetwork([-1], [0], node_evals)
3336

3437
assert r.active == 0
3538
assert len(r.values) == 2
3639
assert len(r.values[0]) == 2
3740
assert len(r.values[1]) == 2
3841

39-
r.activate([1.0])
42+
result = r.activate([1.0])
4043

4144
assert r.active == 1
4245
assert r.values[1][-1] == 1.0
4346
assert_almost_equal(r.values[1][0], 0.731, 0.001)
47+
assert result[0] == r.values[1][0]
4448

45-
r.activate([2.0])
49+
result = r.activate([2.0])
4650

4751
assert r.active == 0
4852
assert r.values[0][-1] == 2.0
4953
assert_almost_equal(r.values[0][0], 0.881, 0.001)
54+
assert result[0] == r.values[0][0]
5055

5156

5257
if __name__ == '__main__':

0 commit comments

Comments
 (0)