Skip to content

Commit 8b44306

Browse files
Fixed old FFGenome tests to use DefaultGenome.
Added convenience add node/connection methods to DefaultGenome. Moved gene validation into validate() methods.
1 parent 7149b38 commit 8b44306

File tree

3 files changed

+49
-30
lines changed

3 files changed

+49
-30
lines changed

neat/genes.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,21 +7,22 @@ class NodeGene(object):
77
""" Encodes parameters for a single artificial neuron. """
88

99
def __init__(self, key, bias, response, aggregation, activation):
10-
# TODO: Move these asserts into an external validation mechanism that can be omitted at runtime if desired.
11-
# Maybe this class should implement a validate(config) method that can optionally be called
12-
# by the NEAT framework?
13-
# TODO: Validate aggregation and activation against current configuration.
14-
assert type(bias) is float
15-
assert type(response) is float
16-
assert type(aggregation) is str
17-
assert type(activation) is str
18-
1910
self.key = key
2011
self.bias = bias
2112
self.response = response
2213
self.aggregation = aggregation
2314
self.activation = activation
2415

16+
# TODO: Implement an external validation mechanism that can be omitted at runtime if desired.
17+
self.validate()
18+
19+
def validate(self):
20+
# TODO: Validate aggregation and activation against current configuration.
21+
assert type(self.bias) is float
22+
assert type(self.response) is float
23+
assert type(self.aggregation) is str
24+
assert type(self.activation) is str
25+
2526
def __str__(self):
2627
return 'NodeGene(key= {0}, bias={1}, response={2}, aggregation={3}, activation={4})'.format(
2728
self.key, self.bias, self.response, self.aggregation, self.activation)
@@ -57,12 +58,6 @@ def mutate(self, config):
5758

5859
class ConnectionGene(object):
5960
def __init__(self, input_id, output_id, weight, enabled):
60-
# TODO: Move these asserts into an external validation mechanism that can be omitted at runtime if desired.
61-
assert type(input_id) is int
62-
assert type(output_id) is int
63-
assert type(weight) is float
64-
assert type(enabled) is bool
65-
6661
self.key = (input_id, output_id)
6762
self.input = input_id
6863
self.output = output_id
@@ -72,6 +67,15 @@ def __init__(self, input_id, output_id, weight, enabled):
7267
# provide a similar effect depending on the weight range and mutation rate.
7368
self.enabled = enabled
7469

70+
# TODO: Implement an external validation mechanism that can be omitted at runtime if desired.
71+
self.validate()
72+
73+
def validate(self):
74+
assert type(self.key) is tuple
75+
assert len(self.key) == 2
76+
assert type(self.weight) is float
77+
assert type(self.enabled) is bool
78+
7579
# TODO: Factor out mutation into a separate class.
7680
def mutate(self, config):
7781
self.weight = config.mutate_weight(self.weight)

neat/genome.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,14 @@ def __init__(self, key, config):
4242
self.fitness = None
4343
self.cross_fitness = None
4444

45+
def add_node(self, key, bias, response, aggregation, activation):
46+
# TODO: Add validation of this node addition.
47+
self.nodes[key] = NodeGene(key, bias, response, aggregation, activation)
48+
49+
def add_connection(self, input_key, output_key, weight, enabled):
50+
# TODO: Add validation of this connection addition.
51+
self.connections[input_key, output_key] = ConnectionGene(input_key, output_key, weight, enabled)
52+
4553
def mutate(self, config):
4654
""" Mutates this genome. """
4755

tests/test_feedforward.py

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import random
22

33
from neat import nn
4-
from neat.genome import FFGenome
4+
from neat.config import Config
5+
from neat.genome import DefaultGenome
56
from neat.genes import NodeGene, ConnectionGene
67

78

@@ -157,11 +158,14 @@ def test_fuzz_feed_forward_layers():
157158

158159

159160
def test_simple_nohidden():
160-
g = FFGenome(0, [0, 1], [2])
161-
g.nodes[2] = NodeGene(2, 0.0, 1.0, 'sum', 'tanh')
162-
g.connections[(0, 2)] = ConnectionGene(0, 2, 1.0, True)
163-
g.connections[(1, 2)] = ConnectionGene(1, 2, -1.0, True)
164-
net = nn.create_feed_forward_phenotype(g)
161+
config = Config()
162+
config.set_input_output_sizes(2, 1)
163+
g = DefaultGenome(0, config)
164+
g.add_node(0, 0.0, 1.0, 'sum', 'tanh')
165+
g.add_connection(-1, 0, 1.0, True)
166+
g.add_connection(-2, 0, -1.0, True)
167+
168+
net = nn.create_feed_forward_phenotype(g, config)
165169

166170
v00 = net.serial_activate([0.0, 0.0])
167171
assert_almost_equal(v00[0], 0.0, 1e-3)
@@ -177,15 +181,18 @@ def test_simple_nohidden():
177181

178182

179183
def test_simple_hidden():
180-
g = FFGenome(0, [0, 1], [2])
181-
g.nodes[2] = NodeGene(2, 0.0, 1.0, 'sum', 'identity')
182-
g.nodes[3] = NodeGene(3, -0.5, 5.0, 'sum', 'sigmoid')
183-
g.nodes[4] = NodeGene(3, -1.5, 5.0, 'sum', 'sigmoid')
184-
g.connections[(0, 3)] = ConnectionGene(0, 3, 1.0, True)
185-
g.connections[(1, 4)] = ConnectionGene(1, 4, 1.0, True)
186-
g.connections[(3, 2)] = ConnectionGene(3, 2, 1.0, True)
187-
g.connections[(4, 2)] = ConnectionGene(4, 2, -1.0, True)
188-
net = nn.create_feed_forward_phenotype(g)
184+
config = Config()
185+
config.set_input_output_sizes(2, 1)
186+
g = DefaultGenome(0, config)
187+
188+
g.add_node(0, 0.0, 1.0, 'sum', 'identity')
189+
g.add_node(1, -0.5, 5.0, 'sum', 'sigmoid')
190+
g.add_node(2, -1.5, 5.0, 'sum', 'sigmoid')
191+
g.add_connection(-1, 1, 1.0, True)
192+
g.add_connection(-2, 2, 1.0, True)
193+
g.add_connection(1, 0, 1.0, True)
194+
g.add_connection(2, 0, -1.0, True)
195+
net = nn.create_feed_forward_phenotype(g, config)
189196

190197
v00 = net.serial_activate([0.0, 0.0])
191198
assert_almost_equal(v00[0], 0.195115, 1e-3)

0 commit comments

Comments
 (0)