Skip to content

Commit d5ba24b

Browse files
Removed explicit support for statistics gathering and checkpointing from Population, as this can be implemented via reporters.
Modified genome distance computation to consider "excess" genes to be the same as disjoint. Refactoring of gene-specific distance computation into gene classes. Renamed the "fully_connected" connection option to "full". Fix minor annoyances in __str__ method formatting. Import commonly used items into the package-level namespace. Update create_feed_forward_phenotype to work with current genome implementation. Updated XOR example to work with current implementation.
1 parent 3753863 commit d5ba24b

File tree

8 files changed

+194
-252
lines changed

8 files changed

+194
-252
lines changed

examples/xor/xor2.py

Lines changed: 44 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,23 +3,55 @@
33
"""
44

55
from __future__ import print_function
6-
6+
import os
77
import neat
8+
import visualize
89

9-
# Inputs and expected outputs.
10+
# 2-input XOR inputs and expected outputs.
1011
xor_inputs = [(0.0, 0.0), (0.0, 1.0), (1.0, 0.0), (1.0, 1.0)]
1112
xor_outputs = [ (0.0,), (1.0,), (1.0,), (0.0,)]
1213

13-
# Create a SequentialStatic instance and use it to evolve a network.
14-
n = neat.SequentialStatic(xor_inputs, xor_outputs)
15-
winner = n.evolve(300)
1614

17-
# Display the winning genome.
18-
print('\nBest genome:\n{!s}'.format(winner))
15+
def eval_genomes(genomes, config):
16+
for genome_id, genome in genomes:
17+
genome.fitness = 1.0
18+
net = neat.nn.create_feed_forward_phenotype(genome, config)
19+
for xi, xo in zip(xor_inputs, xor_outputs):
20+
output = net.activate(xi)
21+
genome.fitness -= (output[0] - xo[0]) ** 2
22+
23+
24+
def run(config_file):
25+
# Load configuration.
26+
config = neat.Config(neat.DefaultGenome, neat.DefaultReproduction,
27+
neat.DefaultStagnation, config_file)
28+
29+
# Create the population, which is the top-level object for a NEAT run.
30+
p = neat.Population(config)
31+
32+
# Add a stdout reporter to show progress in the terminal.
33+
p.add_reporter(neat.StdOutReporter())
34+
35+
# Run for up to 300 generations.
36+
winner = p.run(eval_genomes, 300)
37+
38+
# Display the winning genome.
39+
print('\nBest genome:\n{!s}'.format(winner))
40+
41+
# Show output of the most fit genome against training data.
42+
print('\nOutput:')
43+
winner_net = neat.nn.create_feed_forward_phenotype(winner, config)
44+
for xi, xo in zip(xor_inputs, xor_outputs):
45+
output = winner_net.activate(xi)
46+
print("input {!r}, expected output {!r}, got {!r}".format(xi, xo, output))
1947

20-
# Show output of the most fit genome against training data.
21-
print('\nOutput:')
22-
for inputs, expected, outputs in n.evaluate(winner):
23-
print("input {!r}, expected output {!r}, got {!r}".format(inputs, expected, outputs[0]))
48+
node_names = {-1:'A', -2: 'B', 0:'A XOR B'}
49+
visualize.draw_net(config, winner, True, node_names = node_names)
2450

25-
print("Total number of evaluations: {}".format(n.total_evaluations))
51+
if __name__ == '__main__':
52+
# Determine path to configuration file. This path manipulation is
53+
# here so that the script will run successfully regardless of the
54+
# current working directory.
55+
local_dir = os.path.dirname(__file__)
56+
config_path = os.path.join(local_dir, 'xor2_config')
57+
run(config_path)

examples/xor/xor2_config

Lines changed: 59 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,55 +1,69 @@
11
#--- parameters for the XOR-2 experiment ---#
22

3-
# The `Types` section specifies which classes should be used for various
4-
# tasks in the NEAT algorithm. If you use a non-default class here, you
5-
# must register it with your Config instance before loading the config file.
6-
[Types]
7-
stagnation_type = DefaultStagnation
8-
reproduction_type = DefaultReproduction
9-
genome_type = DefaultGenome
10-
11-
[phenotype]
12-
input_nodes = 2
13-
hidden_nodes = 0
14-
output_nodes = 1
15-
initial_connection = unconnected
16-
max_weight = 30
17-
min_weight = -30
18-
activation_functions = sigmoid
19-
aggregation_functions = sum
20-
weight_stdev = 1.0
3+
# The `NEAT` section specifies parameters particular to the NEAT algorithm
4+
# or the experiment itself. This is the only required section.
5+
[NEAT]
6+
pop_size = 150
7+
max_fitness_threshold = 0.95
8+
reset_on_extinction = False
219

2210
[DefaultGenome]
23-
24-
[genetic]
25-
pop_size = 150
26-
max_fitness_threshold = 0.95
27-
prob_add_conn = 0.988
28-
prob_add_node = 0.085
29-
prob_delete_conn = 0.146
30-
prob_delete_node = 0.0352
31-
prob_mutate_bias = 0.0509
32-
bias_mutation_power = 2.093
33-
prob_mutate_response = 0.1
34-
response_mutation_power = 0.1
35-
prob_mutate_weight = 0.460
36-
prob_replace_weight = 0.0245
37-
weight_mutation_power = 0.825
38-
prob_mutate_activation = 0.0
39-
prob_mutate_aggregation = 0.0
40-
prob_toggle_link = 0.0138
41-
reset_on_extinction = 1
42-
43-
[genotype compatibility]
11+
num_inputs = 2
12+
num_hidden = 0
13+
num_outputs = 1
14+
initial_connection = full
15+
feed_forward = 0
16+
# genome compatibility options
4417
compatibility_threshold = 3.0
45-
excess_coefficient = 1.0
4618
disjoint_coefficient = 1.0
47-
weight_coefficient = 0.4
19+
weight_coefficient = 0.5
20+
# connection add/remove rates
21+
conn_add_prob = 0.5
22+
conn_delete_prob = 0.25
23+
# node add/remove rates
24+
node_add_prob = 0.1
25+
node_delete_prob = 0.05
26+
# node activation options
27+
activation_default = sigmoid
28+
activation_options = sigmoid
29+
activation_mutate_rate = 0.0
30+
# node aggregation options
31+
aggregation_default = sum
32+
aggregation_options = sum
33+
aggregation_mutate_rate = 0.0
34+
# node bias options
35+
bias_init_mean = 0.0
36+
bias_init_stdev = 1.0
37+
bias_replace_rate = 0.1
38+
bias_mutate_rate = 0.7
39+
bias_mutate_power = 0.5
40+
bias_max_value = 30.0
41+
bias_min_value = -30.0
42+
# node response options
43+
response_init_mean = 5.0
44+
response_init_stdev = 0.1
45+
response_replace_rate = 0.1
46+
response_mutate_rate = 0.2
47+
response_mutate_power = 0.1
48+
response_max_value = 30.0
49+
response_min_value = -30.0
50+
# connection weight options
51+
weight_max_value = 30
52+
weight_min_value = -30
53+
weight_init_mean = 0.0
54+
weight_init_stdev = 1.0
55+
weight_mutate_rate = 0.8
56+
weight_replace_rate = 0.1
57+
weight_mutate_power = 0.5
58+
# connection enable options
59+
enabled_default = True
60+
enabled_mutate_rate = 0.01
4861

4962
[DefaultStagnation]
50-
species_fitness_func = mean
51-
max_stagnation = 15
63+
species_fitness = median
64+
max_stagnation = 10
5265

5366
[DefaultReproduction]
54-
elitism = 1
55-
survival_threshold = 0.2
67+
elitism = 2
68+
survival_threshold = 0.2
69+

neat/__init__.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,15 @@
11
from neat.sequential_static import SequentialStatic
22

3+
import neat.nn as nn
4+
import neat.iznn as iznn
5+
6+
from neat.config import Config
7+
from neat.population import Population
8+
from neat.genome import DefaultGenome
9+
from neat.reproduction import DefaultReproduction
10+
from neat.stagnation import DefaultStagnation
11+
from neat.reporting import StdOutReporter
12+
from neat.statistics import StatisticsReporter
13+
14+
15+

neat/config.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,7 @@ class Config(object):
7171

7272
__params = [ConfigParameter('pop_size', int),
7373
ConfigParameter('max_fitness_threshold', float),
74-
ConfigParameter('reset_on_extinction', bool),
75-
ConfigParameter('collect_statistics', bool),
76-
ConfigParameter('report', bool),
77-
ConfigParameter('save_best', bool)]
74+
ConfigParameter('reset_on_extinction', bool)]
7875

7976
def __init__(self, genome_type, reproduction_type, stagnation_type, filename):
8077
# Check that the provided types have the required methods.
@@ -113,12 +110,6 @@ def __init__(self, genome_type, reproduction_type, stagnation_type, filename):
113110
reproduction_dict = dict(parameters.items(reproduction_type.__name__))
114111
self.reproduction_config = reproduction_type.parse_config(reproduction_dict)
115112

116-
# Time in minutes between saving checkpoints, None for no timed checkpoints.
117-
self.checkpoint_time_interval = None
118-
119-
# Time in generations between saving checkpoints, None for no generational checkpoints.
120-
self.checkpoint_gen_interval = None
121-
122113
def save(self, filename):
123114
with open(filename, 'w') as f:
124115
f.write('# The `NEAT` section specifies parameters particular to the NEAT algorithm\n')

neat/genes.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,14 @@
44
# TODO: There is probably a lot of room for simplification of these classes using metaprogramming.
55
# TODO: Evaluate using __slots__ for performance/memory usage improvement.
66

7-
87
class BaseGene(object):
98
def __init__(self, key):
109
self.key = key
1110

1211
def __str__(self):
1312
attrib = ['key'] + [a.name for a in self.__gene_attributes__]
1413
attrib = ['{0}={1}'.format(a, getattr(self, a)) for a in attrib]
15-
return '{0}({1})'.format(__class__.__name__, "".join(attrib))
14+
return '{0}({1})'.format(__class__.__name__, ", ".join(attrib))
1615

1716
def __lt__(self, other):
1817
return self.key < other.key
@@ -62,6 +61,7 @@ def crossover(self, gene2):
6261

6362
# TODO: Create some kind of aggregated config object that can replace
6463
# most of DefaultGeneConfig and genome.DefaultGenomeConfig?
64+
# TODO: Should these be in the nn module? iznn and ctrnn can have additional attributes.
6565

6666
class DefaultGeneConfig(object):
6767
def __init__(self, attribs, params):
@@ -88,8 +88,11 @@ class DefaultNodeGene(BaseGene):
8888
def parse_config(cls, config, param_dict):
8989
return DefaultGeneConfig(cls.__gene_attributes__, param_dict)
9090

91-
def distance(self, other):
92-
raise NotImplementedError()
91+
def distance(self, other, config):
92+
d = abs(self.bias - other.bias) + abs(self.response - other.response)
93+
if self.activation != other.activation:
94+
d += 1.0
95+
return d * config.weight_coefficient
9396

9497

9598
# TODO: Do an ablation study to determine whether the enabled setting is
@@ -103,5 +106,9 @@ class DefaultConnectionGene(BaseGene):
103106
def parse_config(cls, config, param_dict):
104107
return DefaultGeneConfig(cls.__gene_attributes__, param_dict)
105108

106-
def distance(self, other):
107-
raise NotImplementedError()
109+
def distance(self, other, config):
110+
d = abs(self.weight - other.weight)
111+
if self.enabled != other.enabled:
112+
d += 1.0
113+
return d * config.weight_coefficient
114+

0 commit comments

Comments
 (0)