Skip to content

Commit f791d51

Browse files
authored
Add files via upload
1 parent a2eea76 commit f791d51

File tree

1 file changed

+13
-13
lines changed

1 file changed

+13
-13
lines changed

models/logisticRegression.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""
22
Logistic Regression
33
author: Ye Hu
4-
2016/12/14
4+
2016/12/14 update 2017/02/16
55
"""
66
import numpy as np
77
import tensorflow as tf
@@ -42,30 +42,30 @@ def accuarcy(self, y):
4242

4343

4444
if __name__ == "__main__":
45-
# 导入数据
45+
# Load mnist dataset
4646
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
47-
# 定义输入输出Tensor
47+
# Define placeholder for input and target
4848
x = tf.placeholder(tf.float32, shape=[None, 784])
4949
y_ = tf.placeholder(tf.float32, shape=[None, 10])
5050

51-
# 定义分类器
51+
# Construct model
5252
classifier = LogisticRegression(x, n_in=784, n_out=10)
5353
cost = classifier.cost(y_)
5454
accuracy = classifier.accuarcy(y_)
5555
predictor = classifier.y_pred
56-
# 定义训练器
56+
# Define the train operation
5757
train_op = tf.train.GradientDescentOptimizer(learning_rate=0.01).minimize(
5858
cost, var_list=classifier.params)
5959

60-
# 初始化所有变量
60+
# Initialize all variables
6161
init = tf.global_variables_initializer()
6262

63-
# 定义训练参数
63+
# Training settings
6464
training_epochs = 50
6565
batch_size = 100
6666
display_step = 5
6767

68-
# 开始训练
68+
# Train loop
6969
print("Start to train...")
7070
with tf.Session() as sess:
7171
sess.run(init)
@@ -74,11 +74,11 @@ def accuarcy(self, y):
7474
batch_num = int(mnist.train.num_examples/batch_size)
7575
for i in range(batch_num):
7676
x_batch, y_batch = mnist.train.next_batch(batch_size)
77-
# 训练
78-
sess.run(train_op, feed_dict={x: x_batch, y_: y_batch})
79-
# 计算cost
80-
avg_cost += sess.run(cost, feed_dict={x: x_batch, y_: y_batch})/batch_num
81-
# 输出
77+
# Run train op
78+
c, _ = sess.run([cost, train_op], feed_dict={x: x_batch, y_: y_batch})
79+
# Sum up cost
80+
avg_cost += c/batch_num
81+
8282
if epoch % display_step == 0:
8383
val_acc = sess.run(accuracy, feed_dict={x: mnist.validation.images,
8484
y_: mnist.validation.labels})

0 commit comments

Comments
 (0)