Skip to content

Commit d08740e

Browse files
Break out speciation into a user-specifiable class.
StdOutReporter now shows average generation time for last 10 generations. Updated XOR example. Minor fixes/cleanup. Added TODO items.
1 parent c4d146a commit d08740e

11 files changed

+59
-33
lines changed

examples/xor/config-feedforward

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
#--- parameters for the XOR-2 experiment ---#
22

3-
# The `NEAT` section specifies parameters particular to the NEAT algorithm
4-
# or the experiment itself. This is the only required section.
53
[NEAT]
64
max_fitness_threshold = 0.9
75
pop_size = 150
@@ -29,7 +27,6 @@ bias_replace_rate = 0.1
2927

3028
# genome compatibility options
3129
compatibility_disjoint_coefficient = 1.0
32-
compatibility_threshold = 3.0
3330
compatibility_weight_coefficient = 0.5
3431

3532
# connection add/remove rates
@@ -70,6 +67,9 @@ weight_mutate_power = 0.5
7067
weight_mutate_rate = 0.8
7168
weight_replace_rate = 0.1
7269

70+
[DefaultSpeciesSet]
71+
compatibility_threshold = 3.0
72+
7373
[DefaultStagnation]
7474
species_fitness_func = max
7575
max_stagnation = 20

examples/xor/evolve-feedforward-parallel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def run(config_file):
6464
p.add_reporter(stats)
6565

6666
# Run for up to 300 generations.
67-
pe = neat.parallel.ParallelEvaluator(4, eval_genome)
67+
pe = neat.ParallelEvaluator(4, eval_genome)
6868
winner = p.run(pe.evaluate, 300)
6969

7070
# Display the winning genome.

examples/xor/evolve-feedforward.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@ def eval_genomes(genomes, config):
2424
def run(config_file):
2525
# Load configuration.
2626
config = neat.Config(neat.DefaultGenome, neat.DefaultReproduction,
27-
neat.DefaultStagnation, config_file)
27+
neat.DefaultSpeciesSet, neat.DefaultStagnation,
28+
config_file)
2829

2930
# Create the population, which is the top-level object for a NEAT run.
3031
p = neat.Population(config)

neat/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from neat.reproduction import DefaultReproduction
1010
from neat.stagnation import DefaultStagnation
1111
from neat.reporting import StdOutReporter
12+
from neat.species import DefaultSpeciesSet
1213
from neat.statistics import StatisticsReporter
1314
from neat.parallel import ParallelEvaluator
1415

neat/config.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,15 +73,17 @@ class Config(object):
7373
ConfigParameter('max_fitness_threshold', float),
7474
ConfigParameter('reset_on_extinction', bool)]
7575

76-
def __init__(self, genome_type, reproduction_type, stagnation_type, filename):
76+
def __init__(self, genome_type, reproduction_type, species_set_type, stagnation_type, filename):
7777
# Check that the provided types have the required methods.
7878
assert hasattr(genome_type, 'parse_config')
7979
assert hasattr(reproduction_type, 'parse_config')
80+
assert hasattr(species_set_type, 'parse_config')
8081
assert hasattr(stagnation_type, 'parse_config')
8182

8283
self.genome_type = genome_type
83-
self.stagnation_type = stagnation_type
8484
self.reproduction_type = reproduction_type
85+
self.species_set_type = species_set_type
86+
self.stagnation_type = stagnation_type
8587

8688
if not os.path.isfile(filename):
8789
raise Exception('No such config file: ' + os.path.abspath(filename))
@@ -104,6 +106,9 @@ def __init__(self, genome_type, reproduction_type, stagnation_type, filename):
104106
genome_dict = dict(parameters.items(genome_type.__name__))
105107
self.genome_config = genome_type.parse_config(genome_dict)
106108

109+
species_set_dict = dict(parameters.items(species_set_type.__name__))
110+
self.species_set_config = species_set_type.parse_config(species_set_dict)
111+
107112
stagnation_dict = dict(parameters.items(stagnation_type.__name__))
108113
self.stagnation_config = stagnation_type.parse_config(stagnation_dict)
109114

neat/genome.py

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from neat.config import ConfigParameter, write_pretty_params
55
from neat.genes import DefaultConnectionGene, DefaultNodeGene
6-
from neat.six_util import iteritems, itervalues, iterkeys
6+
from neat.six_util import iteritems, iterkeys
77

88
from neat.activations import ActivationFunctionSet
99
from neat.graphs import creates_cycle
@@ -20,7 +20,6 @@ class DefaultGenomeConfig(object):
2020
ConfigParameter('num_outputs', int),
2121
ConfigParameter('num_hidden', int),
2222
ConfigParameter('feed_forward', bool),
23-
ConfigParameter('compatibility_threshold', float),
2423
ConfigParameter('compatibility_disjoint_coefficient', float),
2524
ConfigParameter('compatibility_weight_coefficient', float),
2625
ConfigParameter('conn_add_prob', float),
@@ -67,18 +66,17 @@ def add_activation(self, name, func):
6766
self.activation_defs.add(name, func)
6867

6968
def save(self, f):
70-
# TODO: Handle the initial_connection setting.
71-
# f.write('initial_connection = {0}\n'.format(self.initial_connection))
69+
f.write('initial_connection = {0}\n'.format(self.initial_connection))
7270
# Verify that initial connection type is valid.
73-
# self.initial_connection = params.get('', 'unconnected')
74-
# if 'partial' in self.initial_connection:
75-
# c, p = self.initial_connection.split()
76-
# self.initial_connection = c
77-
# self.connection_fraction = float(p)
78-
# if not (0 <= self.connection_fraction <= 1):
79-
# raise Exception("'partial' connection value must be between 0.0 and 1.0, inclusive.")
80-
#
81-
# assert self.initial_connection in self.allowed_connectivity
71+
if 'partial' in self.initial_connection:
72+
c, p = self.initial_connection.split()
73+
self.initial_connection = c
74+
self.connection_fraction = float(p)
75+
if not (0 <= self.connection_fraction <= 1):
76+
raise Exception("'partial' connection value must be between 0.0 and 1.0, inclusive.")
77+
78+
assert self.initial_connection in self.allowed_connectivity
79+
8280
write_pretty_params(f, self, self.__params)
8381

8482

@@ -322,9 +320,7 @@ def distance(self, other, config):
322320
connection_distance = (connection_distance + config.compatibility_disjoint_coefficient * disjoint_connections) / max_conn
323321

324322
distance = node_distance + connection_distance
325-
compatible = distance < config.compatibility_threshold
326-
327-
return distance, compatible
323+
return distance
328324

329325
def size(self):
330326
'''Returns genome 'complexity', taken to be (number of nodes, number of enabled connections)'''
@@ -392,7 +388,7 @@ def create_unconnected(cls, config, key):
392388

393389
def connect_fs_neat(self, config):
394390
""" Randomly connect one input to all hidden and output nodes (FS-NEAT). """
395-
input_id = choice(self.inputs.keys())
391+
input_id = choice(config.input_keys)
396392
for output_id in list(self.hidden.keys()) + list(self.outputs.keys()):
397393
connection = self.create_connection(config, input_id, output_id)
398394
self.connections[connection.key] = connection

neat/graphs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ def creates_cycle(connections, test):
1010
if i == o:
1111
return True
1212

13-
visited = set([o])
13+
visited = {o}
1414
while True:
1515
num_added = 0
1616
for a, b in connections:

neat/population.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from __future__ import print_function
22

33
from neat.reporting import ReporterSet
4-
from neat.species import SpeciesSet
54
from neat.six_util import iteritems, itervalues
65

76

@@ -21,7 +20,7 @@ def __init__(self, config, initial_state=None):
2120
if initial_state is None:
2221
# Create a population from scratch, then partition into species.
2322
self.population = self.reproduction.create_new(config.genome_type, config.genome_config, config.pop_size)
24-
self.species = SpeciesSet(config)
23+
self.species = config.species_set_type(config)
2524
self.species.speciate(config, self.population)
2625
self.generation = -1
2726
else:

neat/reporting.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from neat.math_util import mean, stdev
66
from neat.six_util import itervalues
77

8+
# TODO: Add a curses-based reporter.
89

910
class ReporterSet(object):
1011
def __init__(self):
@@ -87,14 +88,22 @@ class StdOutReporter(BaseReporter):
8788
def __init__(self):
8889
self.generation = None
8990
self.generation_start_time = None
91+
self.generation_times = []
9092

9193
def start_generation(self, generation):
9294
self.generation = generation
9395
print('\n ****** Running generation {0} ****** \n'.format(generation))
9496
self.generation_start_time = time.time()
9597

9698
def end_generation(self):
97-
print("Generation time: {0:.3f} sec".format(time.time() - self.generation_start_time))
99+
elapsed = time.time() - self.generation_start_time
100+
self.generation_times.append(elapsed)
101+
self.generation_times = self.generation_times[-10:]
102+
average = sum(self.generation_times) / len(self.generation_times)
103+
if len(self.generation_times) > 1:
104+
print("Generation time: {0:.3f} sec ({1:.3f} average)".format(elapsed, average))
105+
else:
106+
print("Generation time: {0:.3f} sec".format(elapsed))
98107

99108
def loading_checkpoint(self, filename):
100109
print('Resuming from a previous point: ' + filename)

neat/reproduction.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ class DefaultReproduction(object):
1717
scheme: explicit fitness sharing with fixed-time species stagnation.
1818
"""
1919

20+
# TODO: Create a separate configuration class instead of using a dict (for consistency with other types).
2021
@classmethod
2122
def parse_config(cls, param_dict):
2223
config = {'elitism': 1,

neat/species.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,16 +34,26 @@ def median_fitness(self):
3434
return fitnesses[len(fitnesses) // 2]
3535

3636

37-
class SpeciesSet(object):
38-
"""
39-
Encapsulates the speciation scheme.
40-
"""
37+
class DefaultSpeciesSet(object):
38+
""" Encapsulates the default speciation scheme. """
4139

4240
def __init__(self, config):
4341
self.indexer = Indexer(1)
4442
self.species = {}
4543
self.to_species = {}
4644

45+
# TODO: Create a separate configuration class instead of using a dict (for consistency with other types).
46+
@classmethod
47+
def parse_config(cls, param_dict):
48+
config = {'compatibility_threshold': float(param_dict['compatibility_threshold'])}
49+
50+
return config
51+
52+
@classmethod
53+
def write_config(cls, f, param_dict):
54+
compatibility_threshold = param_dict['compatibility_threshold']
55+
f.write('compatibility_threshold = {}\n'.format(compatibility_threshold))
56+
4757
def speciate(self, config, population):
4858
"""
4959
Place genomes into species by genetic similarity.
@@ -56,6 +66,8 @@ def speciate(self, config, population):
5666
"""
5767
assert type(population) is dict
5868

69+
compatibility_threshold = config.species_set_config['compatibility_threshold']
70+
5971
# Reset all species member lists.
6072
for s in itervalues(self.species):
6173
s.members.clear()
@@ -69,7 +81,9 @@ def speciate(self, config, population):
6981
closest_species_id = None
7082
for sid, s in iteritems(self.species):
7183
rep = s.representative
72-
distance, compatible = individual.distance(rep, config.genome_config)
84+
distance = individual.distance(rep, config.genome_config)
85+
compatible = distance < compatibility_threshold
86+
7387
if compatible and distance < min_distance:
7488
closest_species = s
7589
closest_species_id = sid

0 commit comments

Comments
 (0)