6
6
from neat .nn import creates_cycle
7
7
8
8
from math import fabs
9
- from random import choice , gauss , random , shuffle
9
+ from random import choice , random , shuffle
10
+
10
11
11
12
class DefaultGenomeConfig (object ):
12
13
__params = [ConfigParameter ('num_inputs' , int ),
@@ -45,7 +46,10 @@ def __init__(self, params):
45
46
for p in self .__params :
46
47
setattr (self , p .name , p .interpret (params ))
47
48
48
- self .build_keys ()
49
+ # By convention, input pins have negative keys, and the output
50
+ # pins have keys 0,1,...
51
+ self .input_keys = [- i - 1 for i in range (self .num_inputs )]
52
+ self .output_keys = [i for i in range (self .num_outputs )]
49
53
50
54
self .connection_fraction = None
51
55
@@ -77,10 +81,6 @@ def save(self, f):
77
81
# assert self.initial_connection in self.allowed_connectivity
78
82
write_pretty_params (f , self , self .__params )
79
83
80
- def build_keys (self ):
81
- self .input_keys = [- i - 1 for i in range (self .num_inputs )]
82
- self .output_keys = [i for i in range (self .num_outputs )]
83
-
84
84
85
85
class DefaultGenome (object ):
86
86
"""
@@ -127,15 +127,6 @@ def __init__(self, key, config):
127
127
self .fitness = None
128
128
self .cross_fitness = None
129
129
130
- def add_node (self , key , bias , response , aggregation , activation ):
131
- # TODO: Add validation of this node addition.
132
- self .nodes [key ] = DefaultNodeGene (key , bias , response , aggregation , activation )
133
-
134
- def add_connection (self , input_key , output_key , weight , enabled ):
135
- # TODO: Add validation of this connection addition.
136
- key = (input_key , output_key )
137
- self .connections [key ] = DefaultConnectionGene (key , weight , enabled )
138
-
139
130
def mutate (self , config ):
140
131
""" Mutates this genome. """
141
132
@@ -217,31 +208,26 @@ def mutate_add_node(self, config):
217
208
# Choose a random connection to split
218
209
conn_to_split = choice (list (self .connections .values ()))
219
210
new_node_id = self .get_new_hidden_id ()
220
- act_func = choice (config .activation_options )
221
211
ng = self .create_node (config , new_node_id )
222
212
self .nodes [new_node_id ] = ng
223
- new_conn1 , new_conn2 = conn_to_split .split (new_node_id )
224
-
225
- # TODO: Make sure this logic is retained in the appropriate place.
226
- # class ConnectionGene(object):
227
- # def split(self, node_id):
228
- # """
229
- # Disable this connection and create two new connections joining its nodes via
230
- # the given node. The new node+connections have roughly the same behavior as
231
- # the original connection (depending on the activation function of the new node).
232
- # """
233
- # self.enabled = False
234
- # new_conn1 = ConnectionGene(self.input, node_id, 1.0, True)
235
- # new_conn2 = ConnectionGene(node_id, self.output, self.weight, True)
236
- #
237
- # return new_conn1, new_conn2
238
-
239
213
214
+ # Disable this connection and create two new connections joining its nodes via
215
+ # the given node. The new node+connections have roughly the same behavior as
216
+ # the original connection (depending on the activation function of the new node).
217
+ conn_to_split .enabled = False
240
218
219
+ i , o = conn_to_split .key
220
+ self .add_connection (config , i , new_node_id , 1.0 , True )
221
+ self .add_connection (config , new_node_id , o , conn_to_split .weight , True )
241
222
242
- self .connections [new_conn1 .key ] = new_conn1
243
- self .connections [new_conn2 .key ] = new_conn2
244
- return ng , conn_to_split # the return is only used in genome_feedforward
223
+ def add_connection (self , config , input_key , output_key , weight , enabled ):
224
+ # TODO: Add validation of this connection addition.
225
+ key = (input_key , output_key )
226
+ connection = DefaultConnectionGene (key )
227
+ connection .init_attributes (config )
228
+ connection .weight = weight
229
+ connection .enabled = enabled
230
+ self .connections [key ] = connection
245
231
246
232
def mutate_add_connection (self , config ):
247
233
'''
@@ -263,10 +249,7 @@ def mutate_add_connection(self, config):
263
249
if config .feed_forward and creates_cycle (list (iterkeys (self .connections )), key ):
264
250
return
265
251
266
- # TODO: factor out new connection creation based on config
267
- weight = gauss (0 , config .weight_stdev )
268
- enabled = choice ([False , True ])
269
- cg = DefaultConnectionGene (in_node , out_node , weight , enabled )
252
+ cg = self .create_connection (config , in_node , out_node )
270
253
self .connections [cg .key ] = cg
271
254
272
255
def mutate_delete_node (self , config ):
@@ -321,7 +304,6 @@ def distance(self, other, config):
321
304
activation_diff = 0
322
305
num_common = 0
323
306
324
-
325
307
# TODO: Factor out the gene-specific distance components into the gene classes.
326
308
327
309
for k2 in node_genes2 .keys ():
@@ -431,8 +413,12 @@ def create_node(config, node_id):
431
413
node = DefaultNodeGene (node_id )
432
414
node .init_attributes (config )
433
415
return node
434
- # return NodeGene(node_id, genome_config.new_bias(), genome_config.new_response(),
435
- # genome_config.new_aggregation(), genome_config.new_activation())
416
+
417
+ @staticmethod
418
+ def create_connection (config , input_id , output_id ):
419
+ connection = DefaultConnectionGene ((input_id , output_id ))
420
+ connection .init_attributes (config )
421
+ return connection
436
422
437
423
@classmethod
438
424
def create_unconnected (cls , config , key ):
@@ -451,9 +437,8 @@ def connect_fs_neat(self, config):
451
437
# TODO: Factor out the gene creation.
452
438
input_id = choice (self .inputs .keys ())
453
439
for output_id in list (self .hidden .keys ()) + list (self .outputs .keys ()):
454
- weight = gauss (0 , config .weight_stdev )
455
- cg = DefaultConnectionGene (input_id , output_id , weight , True )
456
- self .connections [cg .key ] = cg
440
+ connection = self .create_connection (config , input_id , output_id )
441
+ self .connections [connection .key ] = connection
457
442
458
443
def compute_full_connections (self , config ):
459
444
""" Compute connections for a fully-connected feed-forward genome (each input connected to all nodes). """
@@ -466,18 +451,15 @@ def compute_full_connections(self, config):
466
451
467
452
def connect_full (self , config ):
468
453
""" Create a fully-connected genome. """
469
- # TODO: Factor out the gene creation.
470
454
for input_id , output_id in self .compute_full_connections (config ):
471
- weight = gauss (0 , config .weight_stdev )
472
- cg = DefaultConnectionGene (input_id , output_id , weight , True )
473
- self .connections [cg .key ] = cg
455
+ connection = self .create_connection (config , input_id , output_id )
456
+ self .connections [connection .key ] = connection
474
457
475
458
def connect_partial (self , config ):
476
459
assert 0 <= config .connection_fraction <= 1
477
460
all_connections = self .compute_full_connections (config )
478
461
shuffle (all_connections )
479
462
num_to_add = int (round (len (all_connections ) * config .connection_fraction ))
480
- for key in all_connections [:num_to_add ]:
481
- gene = DefaultConnectionGene (key )
482
- gene .init_attributes (config )
483
- self .connections [key ] = gene
463
+ for input_id , output_id in all_connections [:num_to_add ]:
464
+ connection = self .create_connection (config , input_id , output_id )
465
+ self .connections [connection .key ] = connection
0 commit comments