Skip to content

Commit d93c1ea

Browse files
Moved graph algorithms into a separate module.
Validate activation function signatures.
1 parent a0c8d6f commit d93c1ea

File tree

5 files changed

+265
-243
lines changed

5 files changed

+265
-243
lines changed

neat/activations.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import inspect
12
import math
23

34

@@ -22,7 +23,7 @@ def gauss_activation(z):
2223

2324

2425
def relu_activation(z):
25-
return z if z > 0.0 else 0
26+
return z if z > 0.0 else 0.0
2627

2728

2829
def identity_activation(z):
@@ -70,6 +71,15 @@ class InvalidActivationFunction(Exception):
7071
pass
7172

7273

74+
def validate_activation(function):
75+
if not inspect.isfunction(function):
76+
raise InvalidActivationFunction("A function object is required.")
77+
78+
args = inspect.getargspec(function.__call__)
79+
if len(args[0]) != 1:
80+
raise InvalidActivationFunction("A single-argument function is required.")
81+
82+
7383
class ActivationFunctionSet(object):
7484
def __init__(self):
7585
self.functions = {'sigmoid': sigmoid_activation,
@@ -88,7 +98,7 @@ def __init__(self):
8898
'cube': cube_activation}
8999

90100
def add(self, config_name, function):
91-
# TODO: Verify that the given function has the correct signature.
101+
validate_activation(function)
92102
self.functions[config_name] = function
93103

94104
def get(self, config_name):
@@ -99,7 +109,4 @@ def get(self, config_name):
99109
return f
100110

101111
def is_valid(self, config_name):
102-
# TODO: Verify that the given function has the correct signature.
103112
return config_name in self.functions
104-
105-

neat/graphs.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
""" Directed graph algorithm implementations. """
2+
3+
4+
def creates_cycle(connections, test):
5+
"""
6+
Returns true if the addition of the "test" connection would create a cycle,
7+
assuming that no cycle already exists in the graph represented by "connections".
8+
"""
9+
i, o = test
10+
if i == o:
11+
return True
12+
13+
visited = set([o])
14+
while True:
15+
num_added = 0
16+
for a, b in connections:
17+
if a in visited and b not in visited:
18+
if b == i:
19+
return True
20+
21+
visited.add(b)
22+
num_added += 1
23+
24+
if num_added == 0:
25+
return False
26+
27+
28+
def required_for_output(inputs, outputs, connections):
29+
'''
30+
Collect the nodes whose state is required to compute the final network output(s).
31+
:param inputs: list of the input identifiers
32+
:param outputs: list of the output node identifiers
33+
:param connections: list of (input, output) connections in the network.
34+
NOTE: It is assumed that the input identifier set and the node identifier set are disjoint.
35+
By convention, the output node ids are always the same as the output index.
36+
37+
Returns a list of layers, with each layer consisting of a set of identifiers.
38+
'''
39+
40+
required = set(outputs)
41+
S = set(outputs)
42+
while 1:
43+
# Find nodes not in S whose output is consumed by a node in S.
44+
T = set(a for (a, b) in connections if b in S and a not in S)
45+
46+
if not T:
47+
break
48+
49+
layer_nodes = set(x for x in T if x not in inputs)
50+
if not layer_nodes:
51+
break
52+
53+
required = required.union(layer_nodes)
54+
S = S.union(T)
55+
56+
return required
57+
58+
59+
def feed_forward_layers(inputs, outputs, connections):
60+
'''
61+
Collect the layers whose members can be evaluated in parallel in a feed-forward network.
62+
:param inputs: list of the network input nodes
63+
:param outputs: list of the output node identifiers
64+
:param connections: list of (input, output) connections in the network.
65+
66+
Returns a list of layers, with each layer consisting of a set of node identifiers.
67+
Note that the returned layers do not contain nodes whose output is ultimately
68+
never used to compute the final network output.
69+
'''
70+
71+
required = required_for_output(inputs, outputs, connections)
72+
73+
layers = []
74+
S = set(inputs)
75+
while 1:
76+
# Find candidate nodes C for the next layer. These nodes should connect
77+
# a node in S to a node not in S.
78+
C = set(b for (a, b) in connections if a in S and b not in S)
79+
# Keep only the used nodes whose entire input set is contained in S.
80+
T = set()
81+
for n in C:
82+
if n in required and all(a in S for (a, b) in connections if b == n):
83+
T.add(n)
84+
85+
if not T:
86+
break
87+
88+
layers.append(T)
89+
S = S.union(T)
90+
91+
return layers
92+
93+

neat/nn/__init__.py

Lines changed: 1 addition & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -1,97 +1,6 @@
1+
from neat.graphs import creates_cycle, required_for_output, feed_forward_layers
12
from neat.six_util import iterkeys, itervalues, iteritems
23

3-
# TODO: All this directed graph logic should be in a core module.
4-
5-
6-
def creates_cycle(connections, test):
7-
"""
8-
Returns true if the addition of the "test" connection would create a cycle,
9-
assuming that no cycle already exists in the graph represented by "connections".
10-
"""
11-
i, o = test
12-
if i == o:
13-
return True
14-
15-
visited = set([o])
16-
while True:
17-
num_added = 0
18-
for a, b in connections:
19-
if a in visited and b not in visited:
20-
if b == i:
21-
return True
22-
23-
visited.add(b)
24-
num_added += 1
25-
26-
if num_added == 0:
27-
return False
28-
29-
30-
def required_for_output(inputs, outputs, connections):
31-
'''
32-
Collect the nodes whose state is required to compute the final network output(s).
33-
:param inputs: list of the input identifiers
34-
:param outputs: list of the output node identifiers
35-
:param connections: list of (input, output) connections in the network.
36-
NOTE: It is assumed that the input identifier set and the node identifier set are disjoint.
37-
By convention, the output node ids are always the same as the output index.
38-
39-
Returns a list of layers, with each layer consisting of a set of identifiers.
40-
'''
41-
42-
required = set(outputs)
43-
S = set(outputs)
44-
while 1:
45-
# Find nodes not in S whose output is consumed by a node in S.
46-
T = set(a for (a, b) in connections if b in S and a not in S)
47-
48-
if not T:
49-
break
50-
51-
layer_nodes = set(x for x in T if x not in inputs)
52-
if not layer_nodes:
53-
break
54-
55-
required = required.union(layer_nodes)
56-
S = S.union(T)
57-
58-
return required
59-
60-
61-
def feed_forward_layers(inputs, outputs, connections):
62-
'''
63-
Collect the layers whose members can be evaluated in parallel in a feed-forward network.
64-
:param inputs: list of the network input nodes
65-
:param outputs: list of the output node identifiers
66-
:param connections: list of (input, output) connections in the network.
67-
68-
Returns a list of layers, with each layer consisting of a set of node identifiers.
69-
Note that the returned layers do not contain nodes whose output is ultimately
70-
never used to compute the final network output.
71-
'''
72-
73-
required = required_for_output(inputs, outputs, connections)
74-
75-
layers = []
76-
S = set(inputs)
77-
while 1:
78-
# Find candidate nodes C for the next layer. These nodes should connect
79-
# a node in S to a node not in S.
80-
C = set(b for (a, b) in connections if a in S and b not in S)
81-
# Keep only the used nodes whose entire input set is contained in S.
82-
T = set()
83-
for n in C:
84-
if n in required and all(a in S for (a, b) in connections if b == n):
85-
T.add(n)
86-
87-
if not T:
88-
break
89-
90-
layers.append(T)
91-
S = S.union(T)
92-
93-
return layers
94-
954

965
class FeedForwardNetwork(object):
976
def __init__(self, max_node, inputs, outputs, node_evals):

tests/test_feedforward.py

Lines changed: 0 additions & 146 deletions
Original file line numberDiff line numberDiff line change
@@ -1,157 +1,11 @@
11
import random
2-
32
from neat import nn
43

54

65
def assert_almost_equal(x, y, tol):
76
assert abs(x - y) < tol, "{!r} !~= {!r}".format(x, y)
87

98

10-
def test_creates_cycle():
11-
assert nn.creates_cycle([(0, 1), (1, 2), (2, 3)], (0, 0))
12-
13-
assert nn.creates_cycle([(0, 1), (1, 2), (2, 3)], (1, 0))
14-
assert not nn.creates_cycle([(0, 1), (1, 2), (2, 3)], (0, 1))
15-
16-
assert nn.creates_cycle([(0, 1), (1, 2), (2, 3)], (2, 0))
17-
assert not nn.creates_cycle([(0, 1), (1, 2), (2, 3)], (0, 2))
18-
19-
assert nn.creates_cycle([(0, 1), (1, 2), (2, 3)], (3, 0))
20-
assert not nn.creates_cycle([(0, 1), (1, 2), (2, 3)], (0, 3))
21-
22-
assert nn.creates_cycle([(0, 2), (1, 3), (2, 3), (4, 2)], (3, 4))
23-
assert not nn.creates_cycle([(0, 2), (1, 3), (2, 3), (4, 2)], (4, 3))
24-
25-
26-
def test_required_for_output():
27-
inputs = [0, 1]
28-
outputs = [2]
29-
connections = [(0, 2), (1, 2)]
30-
required = nn.required_for_output(inputs, outputs, connections)
31-
assert {2} == required
32-
33-
inputs = [0, 1]
34-
outputs = [2]
35-
connections = [(0, 3), (1, 4), (3, 2), (4, 2)]
36-
required = nn.required_for_output(inputs, outputs, connections)
37-
assert {2, 3, 4} == required
38-
39-
inputs = [0, 1]
40-
outputs = [3]
41-
connections = [(0, 2), (1, 2), (2, 3)]
42-
required = nn.required_for_output(inputs, outputs, connections)
43-
assert {2, 3} == required
44-
45-
inputs = [0, 1]
46-
outputs = [4]
47-
connections = [(0, 2), (1, 2), (1, 3), (2, 3), (2, 4), (3, 4)]
48-
required = nn.required_for_output(inputs, outputs, connections)
49-
assert {2, 3, 4} == required
50-
51-
inputs = [0, 1]
52-
outputs = [4]
53-
connections = [(0, 2), (1, 3), (2, 3), (3, 4), (4, 2)]
54-
required = nn.required_for_output(inputs, outputs, connections)
55-
assert {2, 3, 4} == required
56-
57-
inputs = [0, 1]
58-
outputs = [4]
59-
connections = [(0, 2), (1, 2), (1, 3), (2, 3), (2, 4), (3, 4), (2, 5)]
60-
required = nn.required_for_output(inputs, outputs, connections)
61-
assert {2, 3, 4} == required
62-
63-
64-
def test_fuzz_required():
65-
for _ in range(1000):
66-
n_hidden = random.randint(10, 100)
67-
n_in = random.randint(1, 10)
68-
n_out = random.randint(1, 10)
69-
nodes = list(set(random.randint(0, 1000) for _ in range(n_in + n_out + n_hidden)))
70-
random.shuffle(nodes)
71-
72-
inputs = nodes[:n_in]
73-
outputs = nodes[n_in:n_in + n_out]
74-
connections = []
75-
for _ in range(n_hidden * 2):
76-
a = random.choice(nodes)
77-
b = random.choice(nodes)
78-
if a == b:
79-
continue
80-
if a in inputs and b in inputs:
81-
continue
82-
if a in outputs and b in outputs:
83-
continue
84-
connections.append((a, b))
85-
86-
required = nn.required_for_output(inputs, outputs, connections)
87-
for o in outputs:
88-
assert o in required
89-
90-
91-
def test_feed_forward_layers():
92-
inputs = [0, 1]
93-
outputs = [2]
94-
connections = [(0, 2), (1, 2)]
95-
layers = nn.feed_forward_layers(inputs, outputs, connections)
96-
assert [{2}] == layers
97-
98-
inputs = [0, 1]
99-
outputs = [3]
100-
connections = [(0, 2), (1, 2), (2, 3)]
101-
layers = nn.feed_forward_layers(inputs, outputs, connections)
102-
assert [{2}, {3}] == layers
103-
104-
inputs = [0, 1]
105-
outputs = [4]
106-
connections = [(0, 2), (1, 2), (1, 3), (2, 3), (2, 4), (3, 4)]
107-
layers = nn.feed_forward_layers(inputs, outputs, connections)
108-
assert [{2}, {3}, {4}] == layers
109-
110-
inputs = [0, 1, 2, 3]
111-
outputs = [11, 12, 13]
112-
connections = [(0, 4), (1, 4), (1, 5), (2, 5), (2, 6), (3, 6), (3, 7),
113-
(4, 8), (5, 8), (5, 9), (5, 10), (6, 10), (6, 7),
114-
(8, 11), (8, 12), (8, 9), (9, 10), (7, 10),
115-
(10, 12), (10, 13)]
116-
layers = nn.feed_forward_layers(inputs, outputs, connections)
117-
assert [{4, 5, 6}, {8, 7}, {9, 11}, {10}, {12, 13}] == layers
118-
119-
inputs = [0, 1, 2, 3]
120-
outputs = [11, 12, 13]
121-
connections = [(0, 4), (1, 4), (1, 5), (2, 5), (2, 6), (3, 6), (3, 7),
122-
(4, 8), (5, 8), (5, 9), (5, 10), (6, 10), (6, 7),
123-
(8, 11), (8, 12), (8, 9), (9, 10), (7, 10),
124-
(10, 12), (10, 13),
125-
(3, 14), (14, 15), (5, 16), (10, 16)]
126-
layers = nn.feed_forward_layers(inputs, outputs, connections)
127-
assert [{4, 5, 6}, {8, 7}, {9, 11}, {10}, {12, 13}] == layers
128-
129-
130-
def test_fuzz_feed_forward_layers():
131-
for _ in range(1000):
132-
n_hidden = random.randint(10, 100)
133-
n_in = random.randint(1, 10)
134-
n_out = random.randint(1, 10)
135-
nodes = list(set(random.randint(0, 1000) for _ in range(n_in + n_out + n_hidden)))
136-
random.shuffle(nodes)
137-
138-
inputs = nodes[:n_in]
139-
outputs = nodes[n_in:n_in + n_out]
140-
connections = []
141-
for _ in range(n_hidden * 2):
142-
a = random.choice(nodes)
143-
b = random.choice(nodes)
144-
if a == b:
145-
continue
146-
if a in inputs and b in inputs:
147-
continue
148-
if a in outputs and b in outputs:
149-
continue
150-
connections.append((a, b))
151-
152-
nn.feed_forward_layers(inputs, outputs, connections)
153-
154-
1559
# TODO: Update this test for the current implementation.
15610
# def test_simple_nohidden():
15711
# config_params = {

0 commit comments

Comments
 (0)