Skip to content

Commit f0a8a3b

Browse files
More refactoring.
Refactoring of statistics handling. Added FFGenome test. Added temporary debugging code. Renamed Genome to DefaultGenome. Removed 'feedforward' configuration item in favor of explicitly specifying a genome type. Fixed Config.new_activation bug that chose from all functions instead of the ones specified in the config file. Fixed accidental double usage of mutation probabilities in NodeGene.mutate. Added refactoring reminder comments. Restored keys (formerly called ID) to species/genome/gene class instances because tracking ids separately is more painful.
1 parent 8de3e6b commit f0a8a3b

12 files changed

+222
-202
lines changed

examples/xor/xor2_config

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
[Types]
77
stagnation_type = DefaultStagnation
88
reproduction_type = DefaultReproduction
9+
genome_type = DefaultGenome
910

1011
[phenotype]
1112
input_nodes = 2
@@ -14,11 +15,12 @@ output_nodes = 1
1415
initial_connection = unconnected
1516
max_weight = 30
1617
min_weight = -30
17-
feedforward = 1
1818
activation_functions = sigmoid
1919
aggregation_functions = sum
2020
weight_stdev = 1.0
2121

22+
[FFGenome]
23+
2224
[genetic]
2325
pop_size = 150
2426
max_fitness_threshold = 0.95

neat/config.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from random import random, gauss, choice
22
import os
33

4-
from neat.genome import Genome, FFGenome
4+
from neat.genome import DefaultGenome, FFGenome
55
from neat import activation_functions
66
from neat.reproduction import DefaultReproduction
77
from neat.stagnation import DefaultStagnation
@@ -31,7 +31,9 @@ class Config(object):
3131

3232
def __init__(self, filename=None):
3333
self.registry = {'DefaultStagnation': DefaultStagnation,
34-
'DefaultReproduction': DefaultReproduction}
34+
'DefaultReproduction': DefaultReproduction,
35+
'DefaultGenome': DefaultGenome,
36+
'FFGenome': FFGenome}
3537
self.type_config = {}
3638
if filename is not None:
3739
self.load(filename)
@@ -58,7 +60,6 @@ def load(self, filename):
5860
self.connection_fraction = None
5961
self.max_weight = float(parameters.get('phenotype', 'max_weight'))
6062
self.min_weight = float(parameters.get('phenotype', 'min_weight'))
61-
self.feedforward = bool(int(parameters.get('phenotype', 'feedforward')))
6263
self.weight_stdev = float(parameters.get('phenotype', 'weight_stdev'))
6364
self.activation_functions = parameters.get('phenotype', 'activation_functions').strip().split()
6465
self.aggregation_functions = parameters.get('phenotype', 'aggregation_functions').strip().split()
@@ -78,12 +79,6 @@ def load(self, filename):
7879
if not activation_functions.is_valid(fn):
7980
raise Exception("Invalid activation function name: {0!r}".format(fn))
8081

81-
# Select a genotype class.
82-
if self.feedforward:
83-
self.genotype = FFGenome
84-
else:
85-
self.genotype = Genome
86-
8782
# Genetic algorithm configuration
8883
self.pop_size = int(parameters.get('genetic', 'pop_size'))
8984
self.max_fitness_threshold = float(parameters.get('genetic', 'max_fitness_threshold'))
@@ -111,6 +106,7 @@ def load(self, filename):
111106

112107
stagnation_type_name = parameters.get('Types', 'stagnation_type')
113108
reproduction_type_name = parameters.get('Types', 'reproduction_type')
109+
genome_type_name = parameters.get('Types', 'genome_type')
114110

115111
if stagnation_type_name not in self.registry:
116112
raise Exception('Unknown stagnation type: {!r}'.format(stagnation_type_name))
@@ -122,6 +118,11 @@ def load(self, filename):
122118
self.reproduction_type = self.registry[reproduction_type_name]
123119
self.type_config[reproduction_type_name] = parameters.items(reproduction_type_name)
124120

121+
if genome_type_name not in self.registry:
122+
raise Exception('Unknown reproduction type: {!r}'.format(reproduction_type_name))
123+
self.genome_type = self.registry[genome_type_name]
124+
self.type_config[genome_type_name] = parameters.items(genome_type_name)
125+
125126
# Gather statistics for each generation.
126127
self.collect_statistics = True
127128
# Show stats after each generation.
@@ -151,13 +152,13 @@ def new_bias(self):
151152
return gauss(0, self.weight_stdev)
152153

153154
def new_response(self):
154-
return 1.0
155+
return 5.0
155156

156157
def new_aggregation(self):
157158
return choice(self.aggregation_functions)
158159

159160
def new_activation(self):
160-
return choice(list(activation_functions.functions.keys()))
161+
return choice(self.activation_functions)
161162

162163
def mutate_weight(self, weight):
163164
if random() < self.prob_mutate_weight:

neat/genes.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -45,17 +45,10 @@ def copy(self):
4545

4646
# TODO: Factor out mutation into a separate class.
4747
def mutate(self, config):
48-
if random() < config.prob_mutate_bias:
49-
self.bias = config.mutate_bias(self.bias)
50-
51-
if random() < config.prob_mutate_response:
52-
self.response = config.mutate_response(self.response)
53-
54-
if random() < config.prob_mutate_aggregation:
55-
self.aggregation = config.mutate_aggregation(self.aggregation)
56-
57-
if random() < config.prob_mutate_activation:
58-
self.activation = config.mutate_activation(self.activation)
48+
self.bias = config.mutate_bias(self.bias)
49+
self.response = config.mutate_response(self.response)
50+
self.aggregation = config.mutate_aggregation(self.aggregation)
51+
self.activation = config.mutate_activation(self.activation)
5952

6053

6154
# TODO: Evaluate using __slots__ for performance/memory usage improvement.

neat/genome.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,13 @@
55
from random import choice, gauss, randint, random, shuffle
66

77

8-
class Genome(object):
8+
class DefaultGenome(object):
99
""" A genome for generalized neural networks. """
1010

1111
def __init__(self, key):
12-
# (id, gene) pairs for gene sets.
1312
self.key = key
13+
14+
# (id, gene) pairs for gene sets.
1415
self.connections = {}
1516
self.hidden = {}
1617
self.inputs = {}
@@ -54,11 +55,6 @@ def mutate(self, config):
5455

5556
def crossover(self, other, key):
5657
""" Crosses over parents' genomes and returns a child. """
57-
58-
# Parents must belong to the same species.
59-
#assert self.species_id == other.species_id, 'Different parents species ID: {0} vs {1}'.format(self.species_id,
60-
# other.species_id)
61-
6258
if self.fitness > other.fitness:
6359
parent1 = self
6460
parent2 = other
@@ -68,11 +64,8 @@ def crossover(self, other, key):
6864

6965
# creates a new child
7066
child = self.__class__(key)
71-
7267
child.inherit_genes(parent1, parent2)
7368

74-
#child.species_id = parent1.species_id
75-
7669
return child
7770

7871
def inherit_genes(self, parent1, parent2):
@@ -288,7 +281,7 @@ def add_hidden_nodes(self, num_hidden, config):
288281

289282
@classmethod
290283
def create(cls, config, key):
291-
g = config.genotype.create_unconnected(config, key)
284+
g = config.genome_type.create_unconnected(config, key)
292285

293286
# Add hidden nodes if requested.
294287
if config.hidden_nodes > 0:
@@ -376,7 +369,11 @@ def connect_partial(self, config):
376369

377370

378371

379-
class FFGenome(Genome):
372+
# TODO: This class only differs from DefaultGenome in mutation behavior and
373+
# the node_order member. Its complexity suggests the bar is set too high
374+
# for creating user-defined genome types.
375+
376+
class FFGenome(DefaultGenome):
380377
""" A genome for feed-forward neural networks. Feed-forward
381378
topologies are a particular case of Recurrent NNs.
382379
"""

neat/nn/__init__.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from neat import activation_functions
44
from neat.six_util import iterkeys, itervalues
5+
from neat.config import aggregation_function_defs
56

67

78
def find_feed_forward_layers(inputs, connections):
@@ -50,11 +51,15 @@ def serial_activate(self, inputs):
5051
for i, v in zip(self.input_nodes, inputs):
5152
self.values[i] = v
5253

53-
for node, func, bias, response, links in self.node_evals:
54-
s = 0.0
54+
for node, agg_func, act_func, bias, response, links in self.node_evals:
55+
#print(node, func, bias, response, links)
56+
node_inputs = []
5557
for i, w in links:
56-
s += self.values[i] * w
57-
self.values[node] = func(bias + response * s)
58+
node_inputs.append(self.values[i] * w)
59+
s = agg_func(node_inputs)
60+
self.values[node] = act_func(bias + response * s)
61+
print(" v[{}] = {}({} + {} * {} = {}) = {}".format(node, act_func, bias, response, s, bias + response * s, self.values[node]))
62+
print(self.values)
5863

5964
return [self.values[i] for i in self.output_nodes]
6065

@@ -74,23 +79,25 @@ def create_feed_forward_phenotype(genome):
7479

7580
layers = find_feed_forward_layers(input_nodes, connections)
7681
node_evals = []
77-
#used_nodes = set(input_nodes + output_nodes)
7882
max_used_node = max(max(input_nodes), max(output_nodes))
7983
for layer in layers:
8084
for node in layer:
8185
inputs = []
86+
node_expr = []
8287
# TODO: This could be more efficient.
8388
for cg in itervalues(genome.connections):
8489
if cg.output == node and cg.enabled:
8590
inputs.append((cg.input, cg.weight))
86-
#used_nodes.add(cg.in_node_id)
91+
node_expr.append("v[%d] * %f" % (cg.input, cg.weight))
8792
max_used_node = max(max_used_node, cg.input)
8893

89-
#used_nodes.add(node)
9094
max_used_node = max(max_used_node, node)
9195
ng = all_nodes[node]
96+
aggregation_function = aggregation_function_defs[ng.aggregation]
9297
activation_function = activation_functions.get(ng.activation)
93-
node_evals.append((node, activation_function, ng.bias, ng.response, inputs))
98+
node_evals.append((node, aggregation_function, activation_function, ng.bias, ng.response, inputs))
99+
100+
print(" v[%d] = %s(%f + %f * %s(%s))" % (node, ng.activation, ng.bias, ng.response, ng.aggregation, ", ".join(node_expr)))
94101

95102
return FeedForwardNetwork(max_used_node, input_nodes, output_nodes, node_evals)
96103

neat/population.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
import time
88

99
from neat.config import Config
10-
from neat.reporting import ReporterSet, StatisticsReporter, StdOutReporter
10+
from neat.reporting import ReporterSet, StdOutReporter
11+
from neat.statistics import StatisticsReporter
1112
from neat.species import SpeciesSet
1213
from neat.six_util import iteritems, itervalues
1314

@@ -120,18 +121,16 @@ def run(self, fitness_function, n):
120121
# genome doesn't change--in these cases, evaluating unmodified elites in each
121122
# generation is a waste of time. The user can always take care of this in their
122123
# fitness function in the time being if they wish.
123-
fitness_function(list(iteritems(self.population)))
124+
fitness_function(list(itervalues(self.population)))
124125
#self.total_evaluations += len(self.population)
125126

126127
# Gather and report statistics.
127-
best_id = None
128128
best = None
129129
best_fitness = -sys.float_info.max
130-
for k, v in iteritems(self.population):
131-
if v.fitness > best_fitness:
132-
best = v
133-
best_id = k
134-
self.reporters.post_evaluate(self.population, self.species, best_id, best)
130+
for g in itervalues(self.population):
131+
if g.fitness > best_fitness:
132+
best = g
133+
self.reporters.post_evaluate(self.population, self.species, best)
135134

136135
# Save the best genome from the current generation if requested.
137136
if self.config.save_best:

neat/reporting.py

Lines changed: 8 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
11
from __future__ import print_function
22

3-
import copy
43
import time
5-
import sys
64

75
from neat.math_util import mean, stdev
8-
from neat.six_util import iteritems, itervalues
6+
from neat.six_util import itervalues
97

108

119
class ReporterSet(object):
@@ -34,9 +32,9 @@ def saving_checkpoint(self, checkpoint_type, filename):
3432
for r in self.reporters:
3533
r.saving_checkpoint(checkpoint_type, filename)
3634

37-
def post_evaluate(self, population, species, best_id, best):
35+
def post_evaluate(self, population, species, best_genome):
3836
for r in self.reporters:
39-
r.post_evaluate(population, species, best_id, best)
37+
r.post_evaluate(population, species, best_genome)
4038

4139
def complete_extinction(self):
4240
for r in self.reporters:
@@ -69,7 +67,7 @@ def loading_checkpoint(self, filename):
6967
def saving_checkpoint(self, checkpoint_type, filename):
7068
pass
7169

72-
def post_evaluate(self, population, species, best_id, best):
70+
def post_evaluate(self, population, species, best_genome):
7371
pass
7472

7573
def complete_extinction(self):
@@ -105,14 +103,14 @@ def saving_checkpoint(self, checkpoint_type, filename):
105103
print('Creating {0} checkpoint file {1} at generation: {0}'.format(
106104
checkpoint_type, filename, self.generation))
107105

108-
def post_evaluate(self, population, species, best_id, best):
106+
def post_evaluate(self, population, species, best_genome):
109107
fitnesses = [c.fitness for c in itervalues(population)]
110108
fit_mean = mean(fitnesses)
111109
fit_std = stdev(fitnesses)
112-
best_species_id = species.get_species_id(best_id)
110+
best_species_id = species.get_species_id(best_genome.key)
113111
print('Population\'s average fitness: {0:3.5f} stdev: {1:3.5f}'.format(fit_mean, fit_std))
114-
print('Best fitness: {0:3.5f} - size: {1!r} - species {2} - id {3}'.format(best.fitness, best.size(),
115-
best_species_id, best_id))
112+
print('Best fitness: {0:3.5f} - size: {1!r} - species {2} - id {3}'.format(best_genome.fitness, best_genome.size(),
113+
best_species_id, best_genome.key))
116114
print('Species length: {0:d} totaling {1:d} individuals'.format(len(species.species), len(population)))
117115
#print('Species ID : {0!s}'.format([s.ID for s in species]))
118116
#print('Species size : {0!s}'.format([len(s.members) for s in species]))
@@ -130,69 +128,3 @@ def species_stagnant(self, sid, species):
130128

131129
def info(self, msg):
132130
print(msg)
133-
134-
135-
class StatisticsReporter(BaseReporter):
136-
def __init__(self):
137-
BaseReporter.__init__(self)
138-
self.most_fit_genomes = []
139-
self.generation_statistics = []
140-
self.generation_cross_validation_statistics = []
141-
142-
def post_evaluate(self, population, species, best_id, best):
143-
self.most_fit_genomes.append(copy.deepcopy(best))
144-
145-
# Store the fitnesses of the members of each currently active species.
146-
species_stats = {}
147-
species_cross_validation_stats = {}
148-
for sid, s in iteritems(species.species):
149-
species_stats[sid] = dict((k, v.fitness) for k, v in iteritems(s.members))
150-
species_cross_validation_stats[sid] = dict((k, v.cross_fitness) for k, v in iteritems(s.members))
151-
self.generation_statistics.append(species_stats)
152-
self.generation_cross_validation_statistics.append(species_cross_validation_stats)
153-
154-
def get_average_fitness(self):
155-
"""Get the per-generation average fitness."""
156-
avg_fitness = []
157-
for stats in self.generation_statistics:
158-
scores = []
159-
for fitness in stats.values():
160-
scores.extend(fitness)
161-
avg_fitness.append(mean(scores))
162-
163-
return avg_fitness
164-
165-
def get_average_cross_validation_fitness(self):
166-
"""Get the per-generation average cross_validation fitness."""
167-
avg_cross_validation_fitness = []
168-
for stats in self.generation_cross_validation_statistics:
169-
scores = []
170-
for fitness in stats.values():
171-
scores.extend(fitness)
172-
avg_cross_validation_fitness.append(mean(scores))
173-
174-
return avg_cross_validation_fitness
175-
176-
def best_unique_genomes(self, n):
177-
"""Returns the most n fit genomes, with no duplication."""
178-
best_unique = {}
179-
for g in self.most_fit_genomes:
180-
best_unique[g.ID] = g
181-
best_unique = list(best_unique.values())
182-
183-
def key(genome):
184-
return genome.fitness
185-
186-
return sorted(best_unique, key=key, reverse=True)[:n]
187-
188-
def best_genomes(self, n):
189-
"""Returns the n most fit genomes ever seen."""
190-
def key(g):
191-
return g.fitness
192-
193-
return sorted(self.most_fit_genomes, key=key, reverse=True)[:n]
194-
195-
def best_genome(self):
196-
"""Returns the most fit genome ever seen."""
197-
return self.best_genomes(1)[0]
198-

0 commit comments

Comments
 (0)