Skip to content

Commit 156673a

Browse files
committed
Minor refactoring of network_basic
1 parent 7460477 commit 156673a

File tree

1 file changed

+9
-11
lines changed

1 file changed

+9
-11
lines changed

code/network_basic.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""
2-
network_basic
3-
~~~~~~~~~~~~~~
2+
network_basic
3+
~~~~~~~~~~~~~
44
55
A module to implement the stochastic gradient descent learning
66
algorithm for a feedforward neural network. Gradients are calculated
@@ -42,18 +42,16 @@ def feedforward(self, a):
4242
return a
4343

4444
def SGD(self, training_data, epochs, mini_batch_size, eta,
45-
lmbda, test=False, test_data=None):
45+
lmbda, test_data=None):
4646
"""Train the neural network using mini-batch stochastic
4747
gradient descent. The ``training_data`` is a list of tuples
4848
``(x, y)`` representing the training inputs and the desired
4949
outputs. The other non-optional parameters are
50-
self-explanatory. Set ``test`` to ``True`` to evaluate the
51-
network against the test data after each epoch, and to print
52-
out partial progress. This is useful for tracking progress,
53-
but slows things down substantially. If ``test`` is set, then
54-
appropriate ``test_data`` must be supplied.
55-
"""
56-
if test: n_test = len(test_data)
50+
self-explanatory. If ``test_data`` is provided then the
51+
network will be evaluated against the test data after each
52+
epoch, and partial progress printed out. This is useful for
53+
tracking progress, but slows things down substantially."""
54+
if test_data: n_test = len(test_data)
5755
n = len(training_data)
5856
for j in xrange(epochs):
5957
random.shuffle(training_data)
@@ -62,7 +60,7 @@ def SGD(self, training_data, epochs, mini_batch_size, eta,
6260
for k in xrange(0, n, mini_batch_size)]
6361
for mini_batch in mini_batches:
6462
self.backprop(mini_batch, n, eta, lmbda)
65-
if test:
63+
if test_data:
6664
print "Epoch {}: {} / {}".format(
6765
j, self.evaluate(test_data), n_test)
6866
else:

0 commit comments

Comments
 (0)