Skip to content

Commit 7830482

Browse files
Minor refactoring.
Remove unused code. Added explanatory comments.
1 parent e8b7b55 commit 7830482

File tree

2 files changed

+38
-53
lines changed

2 files changed

+38
-53
lines changed

neat/genes.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# TODO: There is probably a lot of room for simplification of these classes using metaprogramming.
55
# TODO: Evaluate using __slots__ for performance/memory usage improvement.
66

7+
78
class BaseGene(object):
89
def __init__(self, key):
910
self.key = key
@@ -41,6 +42,8 @@ def copy(self):
4142
for a in self.__gene_attributes__:
4243
setattr(new_gene, a.name, getattr(self, a.name))
4344

45+
return new_gene
46+
4447
def crossover(self, gene2):
4548
""" Creates a new gene randomly inheriting attributes from its parents."""
4649
assert self.key == gene2.key

neat/genome.py

Lines changed: 35 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
from neat.nn import creates_cycle
77

88
from math import fabs
9-
from random import choice, gauss, random, shuffle
9+
from random import choice, random, shuffle
10+
1011

1112
class DefaultGenomeConfig(object):
1213
__params = [ConfigParameter('num_inputs', int),
@@ -45,7 +46,10 @@ def __init__(self, params):
4546
for p in self.__params:
4647
setattr(self, p.name, p.interpret(params))
4748

48-
self.build_keys()
49+
# By convention, input pins have negative keys, and the output
50+
# pins have keys 0,1,...
51+
self.input_keys = [-i - 1 for i in range(self.num_inputs)]
52+
self.output_keys = [i for i in range(self.num_outputs)]
4953

5054
self.connection_fraction = None
5155

@@ -77,10 +81,6 @@ def save(self, f):
7781
# assert self.initial_connection in self.allowed_connectivity
7882
write_pretty_params(f, self, self.__params)
7983

80-
def build_keys(self):
81-
self.input_keys = [-i - 1 for i in range(self.num_inputs)]
82-
self.output_keys = [i for i in range(self.num_outputs)]
83-
8484

8585
class DefaultGenome(object):
8686
"""
@@ -127,15 +127,6 @@ def __init__(self, key, config):
127127
self.fitness = None
128128
self.cross_fitness = None
129129

130-
def add_node(self, key, bias, response, aggregation, activation):
131-
# TODO: Add validation of this node addition.
132-
self.nodes[key] = DefaultNodeGene(key, bias, response, aggregation, activation)
133-
134-
def add_connection(self, input_key, output_key, weight, enabled):
135-
# TODO: Add validation of this connection addition.
136-
key = (input_key, output_key)
137-
self.connections[key] = DefaultConnectionGene(key, weight, enabled)
138-
139130
def mutate(self, config):
140131
""" Mutates this genome. """
141132

@@ -217,31 +208,26 @@ def mutate_add_node(self, config):
217208
# Choose a random connection to split
218209
conn_to_split = choice(list(self.connections.values()))
219210
new_node_id = self.get_new_hidden_id()
220-
act_func = choice(config.activation_options)
221211
ng = self.create_node(config, new_node_id)
222212
self.nodes[new_node_id] = ng
223-
new_conn1, new_conn2 = conn_to_split.split(new_node_id)
224-
225-
# TODO: Make sure this logic is retained in the appropriate place.
226-
# class ConnectionGene(object):
227-
# def split(self, node_id):
228-
# """
229-
# Disable this connection and create two new connections joining its nodes via
230-
# the given node. The new node+connections have roughly the same behavior as
231-
# the original connection (depending on the activation function of the new node).
232-
# """
233-
# self.enabled = False
234-
# new_conn1 = ConnectionGene(self.input, node_id, 1.0, True)
235-
# new_conn2 = ConnectionGene(node_id, self.output, self.weight, True)
236-
#
237-
# return new_conn1, new_conn2
238-
239213

214+
# Disable this connection and create two new connections joining its nodes via
215+
# the given node. The new node+connections have roughly the same behavior as
216+
# the original connection (depending on the activation function of the new node).
217+
conn_to_split.enabled = False
240218

219+
i, o = conn_to_split.key
220+
self.add_connection(config, i, new_node_id, 1.0, True)
221+
self.add_connection(config, new_node_id, o, conn_to_split.weight, True)
241222

242-
self.connections[new_conn1.key] = new_conn1
243-
self.connections[new_conn2.key] = new_conn2
244-
return ng, conn_to_split # the return is only used in genome_feedforward
223+
def add_connection(self, config, input_key, output_key, weight, enabled):
224+
# TODO: Add validation of this connection addition.
225+
key = (input_key, output_key)
226+
connection = DefaultConnectionGene(key)
227+
connection.init_attributes(config)
228+
connection.weight = weight
229+
connection.enabled = enabled
230+
self.connections[key] = connection
245231

246232
def mutate_add_connection(self, config):
247233
'''
@@ -263,10 +249,7 @@ def mutate_add_connection(self, config):
263249
if config.feed_forward and creates_cycle(list(iterkeys(self.connections)), key):
264250
return
265251

266-
# TODO: factor out new connection creation based on config
267-
weight = gauss(0, config.weight_stdev)
268-
enabled = choice([False, True])
269-
cg = DefaultConnectionGene(in_node, out_node, weight, enabled)
252+
cg = self.create_connection(config, in_node, out_node)
270253
self.connections[cg.key] = cg
271254

272255
def mutate_delete_node(self, config):
@@ -321,7 +304,6 @@ def distance(self, other, config):
321304
activation_diff = 0
322305
num_common = 0
323306

324-
325307
# TODO: Factor out the gene-specific distance components into the gene classes.
326308

327309
for k2 in node_genes2.keys():
@@ -431,8 +413,12 @@ def create_node(config, node_id):
431413
node = DefaultNodeGene(node_id)
432414
node.init_attributes(config)
433415
return node
434-
# return NodeGene(node_id, genome_config.new_bias(), genome_config.new_response(),
435-
# genome_config.new_aggregation(), genome_config.new_activation())
416+
417+
@staticmethod
418+
def create_connection(config, input_id, output_id):
419+
connection = DefaultConnectionGene((input_id, output_id))
420+
connection.init_attributes(config)
421+
return connection
436422

437423
@classmethod
438424
def create_unconnected(cls, config, key):
@@ -451,9 +437,8 @@ def connect_fs_neat(self, config):
451437
# TODO: Factor out the gene creation.
452438
input_id = choice(self.inputs.keys())
453439
for output_id in list(self.hidden.keys()) + list(self.outputs.keys()):
454-
weight = gauss(0, config.weight_stdev)
455-
cg = DefaultConnectionGene(input_id, output_id, weight, True)
456-
self.connections[cg.key] = cg
440+
connection = self.create_connection(config, input_id, output_id)
441+
self.connections[connection.key] = connection
457442

458443
def compute_full_connections(self, config):
459444
""" Compute connections for a fully-connected feed-forward genome (each input connected to all nodes). """
@@ -466,18 +451,15 @@ def compute_full_connections(self, config):
466451

467452
def connect_full(self, config):
468453
""" Create a fully-connected genome. """
469-
# TODO: Factor out the gene creation.
470454
for input_id, output_id in self.compute_full_connections(config):
471-
weight = gauss(0, config.weight_stdev)
472-
cg = DefaultConnectionGene(input_id, output_id, weight, True)
473-
self.connections[cg.key] = cg
455+
connection = self.create_connection(config, input_id, output_id)
456+
self.connections[connection.key] = connection
474457

475458
def connect_partial(self, config):
476459
assert 0 <= config.connection_fraction <= 1
477460
all_connections = self.compute_full_connections(config)
478461
shuffle(all_connections)
479462
num_to_add = int(round(len(all_connections) * config.connection_fraction))
480-
for key in all_connections[:num_to_add]:
481-
gene = DefaultConnectionGene(key)
482-
gene.init_attributes(config)
483-
self.connections[key] = gene
463+
for input_id, output_id in all_connections[:num_to_add]:
464+
connection = self.create_connection(config, input_id, output_id)
465+
self.connections[connection.key] = connection

0 commit comments

Comments
 (0)