Skip to content

Commit f11789e

Browse files
committed
add some comments
1 parent ac6d78f commit f11789e

File tree

1 file changed

+11
-6
lines changed
  • pyTorch/BookPyTorch2/chapter4

1 file changed

+11
-6
lines changed

pyTorch/BookPyTorch2/chapter4/ann.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@ def sigmod_derivate(x):
1414
return x * (1 - x)
1515
class BPNeuralNetwork:
1616
def __init__(self):
17-
self.input_n = 0
18-
self.hidden_n = 0
19-
self.output_n = 0
17+
self.input_n = 0 # 输入维度
18+
self.hidden_n = 0 # 隐藏层维度,只有一个 hidden layer
19+
self.output_n = 0 # 输出维度
2020
self.input_cells = []
2121
self.hidden_cells = []
2222
self.output_cells = []
@@ -31,14 +31,16 @@ def setup(self,ni,nh,no):
3131
self.output_cells = [1.0] * self.output_n
3232
self.input_weights = make_matrix(self.input_n,self.hidden_n)
3333
self.output_weights = make_matrix(self.hidden_n,self.output_n)
34-
# random activate
34+
# random activate,随机初始化 weight
3535
for i in range(self.input_n):
3636
for h in range(self.hidden_n):
3737
self.input_weights[i][h] = rand(-0.2, 0.2)
3838
for h in range(self.hidden_n):
3939
for o in range(self.output_n):
4040
self.output_weights[h][o] = rand(-2.0, 2.0)
41-
def predict(self,inputs):
41+
42+
# todo 在推理的时候才需要前向传播
43+
def predict(self,inputs): # forward propagation
4244
for i in range(self.input_n - 1):
4345
self.input_cells[i] = inputs[i]
4446
for j in range(self.hidden_n):
@@ -75,9 +77,10 @@ def back_propagate(self,case,label,learn):
7577
for j in range(self.hidden_n):
7678
self.input_weights[i][j] += learn * hidden_deltas[j] * self.input_cells[i]
7779
error = 0
78-
for o in range(len(label)):
80+
for o in range(len(label)): # todo mse error
7981
error += 0.5 * (label[o] - self.output_cells[o]) ** 2
8082
return error
83+
# todo 训练过程从随机初始化的权重开始,直接进行反向传播,不需要前向传播过程
8184
def train(self,cases,labels,limit = 100,learn = 0.05):
8285
for i in range(limit):
8386
error = 0
@@ -96,6 +99,8 @@ def test(self):
9699
labels = [[0], [1], [1], [0]]
97100
self.setup(2, 5, 1)
98101
self.train(cases, labels, 10000, 0.05)
102+
103+
# todo 执行预测(推理),期望结果应该接近于 labels 里的值
99104
for case in cases:
100105
print(self.predict(case))
101106
if __name__ == '__main__':

0 commit comments

Comments
 (0)