Skip to content

Commit 7149b38

Browse files
Another high-level refactoring pass.
Created the SequentialStatic wrapper to handle the "evolve a feed-forward network for tabular data" case boilerplate code. ActivationFunctionSet now auto-adds the full set of activation functions in the initializer. Config now sets default member values in the initializer, and the load() method must now explicitly be called. Replaced the input, hidden and output node gene sets with a single node gene collection; adjusted network conventions to support this. Pass config object around as necessary to support current refactoring. Added refactoring reminder comments. Expanded feed-forward network tests.
1 parent f0a8a3b commit 7149b38

15 files changed

+669
-344
lines changed

examples/xor/xor2.py

Lines changed: 16 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -1,77 +1,25 @@
1-
""" 2-input XOR example """
2-
from __future__ import print_function
3-
4-
import os
5-
6-
from neat import nn, population, statistics
7-
8-
9-
# Network inputs and expected outputs.
10-
xor_inputs = [[0, 0], [0, 1], [1, 0], [1, 1]]
11-
xor_outputs = [0, 1, 1, 0]
12-
13-
def ideal_demo():
14-
import neat
15-
16-
# default Config
17-
# default to parallel processing using auto-detected # hardware cores
18-
n = neat.Sequential(xor_inputs, xor_outputs)
19-
20-
n.evolve(300)
21-
22-
n.save_statistics(".")
23-
24-
print('Number of evaluations: {0}'.format(n.total_evaluations))
25-
26-
# Show output of the most fit genome against training data.
27-
winner = n.best_genome()
28-
print('\nBest genome:\n{!s}'.format(winner))
29-
print('\nOutput:')
30-
winner_output = n.evaluate(winner, xor_inputs)
31-
for inputs, expected, outputs in zip(xor_inputs, xor_outputs, winner_output):
32-
print("input {!r}, expected output {0:1.5f} got {1:1.5f}".format(inputs, expected, outputs[0]))
1+
"""
2+
2-input XOR example -- this is most likely the simplest possible example.
3+
"""
334

5+
from __future__ import print_function
346

7+
import neat
358

36-
total_evaluations = 0
37-
38-
def eval_fitness(genomes):
39-
global total_evaluations
40-
total_evaluations += len(genomes)
41-
42-
for gid, g in genomes:
43-
44-
net = nn.create_feed_forward_phenotype(g)
45-
46-
sum_square_error = 0.0
47-
for inputs, expected in zip(xor_inputs, xor_outputs):
48-
# Serial activation propagates the inputs through the entire network.
49-
output = net.serial_activate(inputs)
50-
sum_square_error += (output[0] - expected) ** 2
51-
52-
# When the output matches expected for all inputs, fitness will reach
53-
# its maximum value of 1.0.
54-
g.fitness = 1 - sum_square_error
55-
56-
57-
local_dir = os.path.dirname(__file__)
58-
config_path = os.path.join(local_dir, 'xor2_config')
59-
pop = population.Population(config_path)
60-
pop.run(eval_fitness, 300)
9+
# Inputs and expected outputs.
10+
xor_inputs = [(0.0, 0.0), (0.0, 1.0), (1.0, 0.0), (1.0, 1.0)]
11+
xor_outputs = [ (0.0,), (1.0,), (1.0,), (0.0,)]
6112

62-
# Log statistics.
63-
statistics.save_stats(pop.statistics)
64-
statistics.save_species_count(pop.statistics)
65-
statistics.save_species_fitness(pop.statistics)
13+
# Create a SequentialStatic instance and use it to evolve a network.
14+
n = neat.SequentialStatic(xor_inputs, xor_outputs)
15+
winner = n.evolve(300)
6616

67-
print('Number of evaluations: {0}'.format(total_evaluations))
17+
# Display the winning genome.
18+
print('\nBest genome:\n{!s}'.format(winner))
6819

6920
# Show output of the most fit genome against training data.
70-
winner = pop.statistics.best_genome()
71-
print('\nBest genome:\n{!s}'.format(winner))
7221
print('\nOutput:')
73-
winner_net = nn.create_feed_forward_phenotype(winner)
74-
for inputs, expected in zip(xor_inputs, xor_outputs):
75-
output = winner_net.serial_activate(inputs)
76-
print("expected {0:1.5f} got {1:1.5f}".format(expected, output[0]))
22+
for inputs, expected, outputs in n.evaluate(winner):
23+
print("input {!r}, expected output {!r}, got {!r}".format(inputs, expected, outputs[0]))
7724

25+
print("Total number of evaluations: {}".format(n.total_evaluations))

examples/xor/xor2_config

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ activation_functions = sigmoid
1919
aggregation_functions = sum
2020
weight_stdev = 1.0
2121

22-
[FFGenome]
22+
[DefaultGenome]
2323

2424
[genetic]
2525
pop_size = 150

neat/__init__.py

Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,2 @@
1-
from neat import activations
1+
from neat.sequential_static import SequentialStatic
22

3-
# TODO: This collection should probably be held by the Config object.
4-
activation_functions = activations.ActivationFunctionSet()
5-
6-
activation_functions.add('sigmoid', activations.sigmoid_activation)
7-
activation_functions.add('tanh', activations.tanh_activation)
8-
activation_functions.add('sin', activations.sin_activation)
9-
activation_functions.add('gauss', activations.gauss_activation)
10-
activation_functions.add('relu', activations.relu_activation)
11-
activation_functions.add('identity', activations.identity_activation)
12-
activation_functions.add('clamped', activations.clamped_activation)
13-
activation_functions.add('inv', activations.inv_activation)
14-
activation_functions.add('log', activations.log_activation)
15-
activation_functions.add('exp', activations.exp_activation)
16-
activation_functions.add('abs', activations.abs_activation)
17-
activation_functions.add('hat', activations.hat_activation)
18-
activation_functions.add('square', activations.square_activation)
19-
activation_functions.add('cube', activations.cube_activation)

neat/activations.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,20 @@ class InvalidActivationFunction(Exception):
7272

7373
class ActivationFunctionSet(object):
7474
def __init__(self):
75-
self.functions = {}
75+
self.functions = {'sigmoid': sigmoid_activation,
76+
'tanh': tanh_activation,
77+
'sin': sin_activation,
78+
'gauss': gauss_activation,
79+
'relu': relu_activation,
80+
'identity': identity_activation,
81+
'clamped': clamped_activation,
82+
'inv': inv_activation,
83+
'log': log_activation,
84+
'exp': exp_activation,
85+
'abs': abs_activation,
86+
'hat': hat_activation,
87+
'square': square_activation,
88+
'cube': cube_activation}
7689

7790
def add(self, config_name, function):
7891
# TODO: Verify that the given function has the correct signature.
@@ -86,6 +99,7 @@ def get(self, config_name):
8699
return f
87100

88101
def is_valid(self, config_name):
102+
# TODO: Verify that the given function has the correct signature.
89103
return config_name in self.functions
90104

91105

neat/config.py

Lines changed: 90 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
from random import random, gauss, choice
21
import os
2+
from random import random, gauss, choice
33

4-
from neat.genome import DefaultGenome, FFGenome
5-
from neat import activation_functions
4+
from neat.activations import ActivationFunctionSet
5+
from neat.genome import DefaultGenome
66
from neat.reproduction import DefaultReproduction
77
from neat.stagnation import DefaultStagnation
88

@@ -13,6 +13,7 @@
1313

1414
aggregation_function_defs = {'sum': sum, 'max': max, 'min': min}
1515

16+
1617
class Config(object):
1718
'''
1819
A simple container for all of the user-configurable parameters of NEAT.
@@ -29,14 +30,81 @@ class Config(object):
2930

3031
allowed_connectivity = ['unconnected', 'fs_neat', 'fully_connected', 'partial']
3132

32-
def __init__(self, filename=None):
33+
def __init__(self):
34+
# Initialize type registry with default implementations.
3335
self.registry = {'DefaultStagnation': DefaultStagnation,
3436
'DefaultReproduction': DefaultReproduction,
35-
'DefaultGenome': DefaultGenome,
36-
'FFGenome': FFGenome}
37+
'DefaultGenome': DefaultGenome}
3738
self.type_config = {}
38-
if filename is not None:
39-
self.load(filename)
39+
40+
# Phenotype configuration
41+
self.input_nodes = 0
42+
self.output_nodes = 0
43+
self.hidden_nodes = 0
44+
self.initial_connection = 'unconnected'
45+
self.connection_fraction = None
46+
self.max_weight = 30.0
47+
self.min_weight = -30.0
48+
self.weight_stdev = 1.0
49+
self.activation_functions = ['sigmoid']
50+
self.aggregation_functions = ['sum']
51+
52+
# Genetic algorithm configuration
53+
self.pop_size = 150
54+
self.max_fitness_threshold = -0.05
55+
self.prob_add_conn = 0.5
56+
self.prob_add_node = 0.1
57+
self.prob_delete_conn = 0.1
58+
self.prob_delete_node = 0.05
59+
self.prob_mutate_bias = 0.05
60+
self.bias_mutation_power = 2.0
61+
self.prob_mutate_response = 0.5
62+
self.response_mutation_power = 0.1
63+
self.prob_mutate_weight = 0.5
64+
self.prob_replace_weight = 0.02
65+
self.weight_mutation_power = 0.8
66+
self.prob_mutate_activation = 0.0
67+
self.prob_mutate_aggregation = 0.0
68+
self.prob_toggle_link = 0.01
69+
self.reset_on_extinction = True
70+
71+
# genotype compatibility
72+
self.compatibility_threshold = 3.0
73+
self.excess_coefficient = 1.0
74+
self.disjoint_coefficient = 1.0
75+
self.weight_coefficient = 0.4
76+
77+
stagnation_type_name = 'DefaultStagnation'
78+
self.stagnation_type = self.registry[stagnation_type_name]
79+
# TODO: Look up the default type configuration from a static method on the type?
80+
self.type_config[stagnation_type_name] = {'species_fitness_func': 'mean',
81+
'max_stagnation': 15}
82+
83+
reproduction_type_name = 'DefaultReproduction'
84+
self.reproduction_type = self.registry[reproduction_type_name]
85+
# TODO: Look up the default type configuration from a static method on the type?
86+
self.type_config[reproduction_type_name] = {'elitism': 1,
87+
'survival_threshold': 0.2}
88+
89+
genome_type_name = 'DefaultGenome'
90+
self.genome_type = self.registry[genome_type_name]
91+
# TODO: Look up the default type configuration from a static method on the type?
92+
self.type_config[genome_type_name] = {}
93+
94+
# Gather statistics for each generation.
95+
self.collect_statistics = True
96+
# Show stats after each generation.
97+
self.report = True
98+
# Save the best genome from each generation.
99+
self.save_best = False
100+
# Time in minutes between saving checkpoints, None for no timed checkpoints.
101+
self.checkpoint_time_interval = None
102+
# Time in generations between saving checkpoints, None for no generational checkpoints.
103+
self.checkpoint_gen_interval = None
104+
105+
# Create full set of available activation functions.
106+
# TODO: pick a better name for this member, it's too confusing alongside activation_functions.
107+
self.available_activations = ActivationFunctionSet()
40108

41109
def load(self, filename):
42110
if not os.path.isfile(filename):
@@ -76,7 +144,7 @@ def load(self, filename):
76144

77145
# Verify that specified activation functions are valid.
78146
for fn in self.activation_functions:
79-
if not activation_functions.is_valid(fn):
147+
if not self.available_activations.is_valid(fn):
80148
raise Exception("Invalid activation function name: {0!r}".format(fn))
81149

82150
# Genetic algorithm configuration
@@ -123,28 +191,26 @@ def load(self, filename):
123191
self.genome_type = self.registry[genome_type_name]
124192
self.type_config[genome_type_name] = parameters.items(genome_type_name)
125193

126-
# Gather statistics for each generation.
127-
self.collect_statistics = True
128-
# Show stats after each generation.
129-
self.report = True
130-
# Save the best genome from each generation.
131-
self.save_best = False
132-
# Time in minutes between saving checkpoints, None for no timed checkpoints.
133-
self.checkpoint_time_interval = None
134-
# Time in generations between saving checkpoints, None for no generational checkpoints.
135-
self.checkpoint_gen_interval = None
194+
def set_input_output_sizes(self, num_inputs, num_outputs):
195+
self.input_nodes = num_inputs
196+
self.output_nodes = num_outputs
197+
self.input_keys = [-i-1 for i in range(self.input_nodes)]
198+
self.output_keys = [i for i in range(self.output_nodes)]
199+
200+
def save(self, filename):
201+
pass
136202

137-
def register(self, typeName, typeDef):
203+
def register(self, type_name, type_def):
138204
"""
139205
User-defined classes mentioned in the config file must be provided to the
140206
configuration object before the load() method is called.
141207
"""
142-
self.registry[typeName] = typeDef
208+
self.registry[type_name] = type_def
143209

144-
def get_type_config(self, typeInstance):
145-
return dict(self.type_config[typeInstance.__class__.__name__])
210+
def get_type_config(self, type_instance):
211+
return dict(self.type_config[type_instance.__class__.__name__])
146212

147-
# TODO: Factor out these mutation methods into a separate class.
213+
# TODO: Factor out these mutation methods into a separate class?
148214
def new_weight(self):
149215
return gauss(0, self.weight_stdev)
150216

neat/genes.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ class NodeGene(object):
88

99
def __init__(self, key, bias, response, aggregation, activation):
1010
# TODO: Move these asserts into an external validation mechanism that can be omitted at runtime if desired.
11+
# Maybe this class should implement a validate(config) method that can optionally be called
12+
# by the NEAT framework?
1113
# TODO: Validate aggregation and activation against current configuration.
1214
assert type(bias) is float
1315
assert type(response) is float

0 commit comments

Comments
 (0)