Skip to content

Commit a9aeb21

Browse files
ActivationFunctionSet now uses its own add method to add the default set of functions.
Added minimal type validation for Config initializer types. Updated add_hidden_nodes to work with current implementation.
1 parent d93c1ea commit a9aeb21

File tree

3 files changed

+39
-44
lines changed

3 files changed

+39
-44
lines changed

neat/activations.py

Lines changed: 23 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -82,31 +82,32 @@ def validate_activation(function):
8282

8383
class ActivationFunctionSet(object):
8484
def __init__(self):
85-
self.functions = {'sigmoid': sigmoid_activation,
86-
'tanh': tanh_activation,
87-
'sin': sin_activation,
88-
'gauss': gauss_activation,
89-
'relu': relu_activation,
90-
'identity': identity_activation,
91-
'clamped': clamped_activation,
92-
'inv': inv_activation,
93-
'log': log_activation,
94-
'exp': exp_activation,
95-
'abs': abs_activation,
96-
'hat': hat_activation,
97-
'square': square_activation,
98-
'cube': cube_activation}
99-
100-
def add(self, config_name, function):
85+
self.functions = {}
86+
self.add('sigmoid', sigmoid_activation)
87+
self.add('tanh', tanh_activation)
88+
self.add('sin', sin_activation)
89+
self.add('gauss', gauss_activation)
90+
self.add('relu', relu_activation)
91+
self.add('identity', identity_activation)
92+
self.add('clamped', clamped_activation)
93+
self.add('inv', inv_activation)
94+
self.add('log', log_activation)
95+
self.add('exp', exp_activation)
96+
self.add('abs', abs_activation)
97+
self.add('hat', hat_activation)
98+
self.add('square', square_activation)
99+
self.add('cube', cube_activation)
100+
101+
def add(self, name, function):
101102
validate_activation(function)
102-
self.functions[config_name] = function
103+
self.functions[name] = function
103104

104-
def get(self, config_name):
105-
f = self.functions.get(config_name)
105+
def get(self, name):
106+
f = self.functions.get(name)
106107
if f is None:
107-
raise InvalidActivationFunction("No such function: {0!r}".format(config_name))
108+
raise InvalidActivationFunction("No such activation function: {0!r}".format(name))
108109

109110
return f
110111

111-
def is_valid(self, config_name):
112-
return config_name in self.functions
112+
def is_valid(self, name):
113+
return name in self.functions

neat/config.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,7 @@ def write_pretty_params(f, config, params):
6767

6868

6969
class Config(object):
70-
'''
71-
A simple container for all of the user-configurable parameters of NEAT.
72-
'''
70+
''' A simple container for user-configurable parameters of NEAT. '''
7371

7472
__params = [ConfigParameter('pop_size', int),
7573
ConfigParameter('max_fitness_threshold', float),
@@ -79,14 +77,19 @@ class Config(object):
7977
ConfigParameter('save_best', bool)]
8078

8179
def __init__(self, genome_type, reproduction_type, stagnation_type, filename):
80+
# Check that the provided types have the required methods.
81+
assert hasattr(genome_type, 'parse_config')
82+
assert hasattr(reproduction_type, 'parse_config')
83+
assert hasattr(stagnation_type, 'parse_config')
84+
8285
self.genome_type = genome_type
8386
self.stagnation_type = stagnation_type
8487
self.reproduction_type = reproduction_type
8588

8689
if not os.path.isfile(filename):
8790
raise Exception('No such config file: ' + os.path.abspath(filename))
88-
parameters = ConfigParser()
8991

92+
parameters = ConfigParser()
9093
with open(filename) as f:
9194
if hasattr(parameters, 'read_file'):
9295
parameters.read_file(f)

neat/genome.py

Lines changed: 9 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,6 @@ def __init__(self, params):
3030
# Create full set of available activation functions.
3131
self.activation_defs = ActivationFunctionSet()
3232
self.activation_options = params.get('activation_options', 'sigmoid').strip().split()
33-
34-
# TODO: Verify that specified activation functions are valid before using them.
35-
# for fn in self.activation:
36-
# if not self.activation_defs.is_valid(fn):
37-
# raise Exception("Invalid activation function name: {0!r}".format(fn))
38-
3933
self.aggregation_options = params.get('aggregation_options', 'sum').strip().split()
4034

4135
# Gather configuration data from the gene classes.
@@ -195,7 +189,7 @@ def inherit_genes(self, parent1, parent2):
195189
# Homologous gene: combine genes from both parents.
196190
self.nodes[key] = ng1.crossover(ng2)
197191

198-
def get_new_hidden_id(self):
192+
def get_new_node_key(self):
199193
new_id = 0
200194
while new_id in self.nodes:
201195
new_id += 1
@@ -207,7 +201,7 @@ def mutate_add_node(self, config):
207201

208202
# Choose a random connection to split
209203
conn_to_split = choice(list(self.connections.values()))
210-
new_node_id = self.get_new_hidden_id()
204+
new_node_id = self.get_new_node_key()
211205
ng = self.create_node(config, new_node_id)
212206
self.nodes[new_node_id] = ng
213207

@@ -380,23 +374,20 @@ def __str__(self):
380374
s += "\n\t" + str(c)
381375
return s
382376

383-
def add_hidden_nodes(self, num_hidden, config):
384-
node_id = self.get_new_hidden_id()
385-
for i in range(num_hidden):
386-
# TODO: factor out new node creation.
387-
act_func = choice(config.activation)
388-
node_gene = config.node_gene_type(activation_type=act_func)
389-
assert node_id not in self.hidden
390-
self.hidden[node_id] = node_gene
391-
node_id += 1
377+
def add_hidden_nodes(self, config):
378+
for i in range(config.num_hidden):
379+
node_key = self.get_new_node_key()
380+
assert node_key not in self.nodes
381+
node = self.__class__.create_node(config, node_key)
382+
self.nodes[node_key] = node
392383

393384
@classmethod
394385
def create(cls, config, key):
395386
g = cls.create_unconnected(config, key)
396387

397388
# Add hidden nodes if requested.
398389
if config.num_hidden > 0:
399-
g.add_hidden_nodes(config.num_hidden)
390+
g.add_hidden_nodes(config)
400391

401392
# Add connections based on initial connectivity type.
402393
if config.initial_connection == 'fs_neat':

0 commit comments

Comments
 (0)