Skip to content

Commit ca9da53

Browse files
committed
Added deep_learning module
1 parent b5c9464 commit ca9da53

File tree

1 file changed

+46
-0
lines changed

1 file changed

+46
-0
lines changed

code/deep_learning.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
"""
2+
deep_learning
3+
~~~~~~~~~~~~~
4+
5+
Module to do deep learning. Most of the functionality needed is
6+
already in the ``backprop2`` and ``deep_autoencoder`` modules, but
7+
this adds convenience functions to help in doing things like unrolling
8+
deep autoencoders, and adding and training a classifier layer."""
9+
10+
# My Libraries
11+
from backprop2 import Network
12+
from deep_autoencoder import DeepAutoencoder
13+
14+
def unroll(deep_autoencoder):
15+
"""
16+
Return a Network that contains the compression stage of the
17+
``deep_autoencoder``."""
18+
net = Network(deep_autoencoder.layers)
19+
net.weights = deep_autoencoder.weights[:len(deep_autoencoder.layers)-1]
20+
net.biases = deep_autoencoder.biases[:len(deep_autoencoder.layers)-1]
21+
return net
22+
23+
def add_classifier_layer(net, num_outputs):
24+
"""
25+
Return the Network ``net``, but with an extra layer containing
26+
``num_outputs`` neurons appended."""
27+
net_classifier = Network(net.sizes+[num_outputs])
28+
net_classifier.weights[:-1] = net.weights
29+
net_classifier.biases[:-1] = net.biases
30+
return net_classifier
31+
32+
def SGD_final_layer(
33+
self, training_data, epochs, mini_batch_size, eta, lmbda,
34+
test=False, test_inputs=None, actual_test_results=None):
35+
"""
36+
Run SGD on the final layer of the Network ``self``. Note that
37+
``training_data`` is the input to the whole Network, not the
38+
encoded training data input to the final layer.
39+
"""
40+
net = Network([self.sizes[-2], self.sizes[-1]])
41+
encoded_training_data = self.feedforward(training_data, start=0, end=-1)
42+
net.SGD(encoded_training_data, epochs, mini_batch_size, eta,
43+
lmbda, test, test_inputs, actual_test_results)
44+
45+
# Add the SGD_final_layer method to the Network class
46+
Network.SGD_final_layer = SGD_final_layer

0 commit comments

Comments
 (0)