Skip to content

Commit 2221952

Browse files
authored
Merge pull request yunjey#120 from arisliang/master
minor refactoring and fix
2 parents 2537490 + e217e42 commit 2221952

File tree

3 files changed

+7
-5
lines changed
  • tutorials
    • 01-basics/logistic_regression
    • 02-intermediate

3 files changed

+7
-5
lines changed

tutorials/01-basics/logistic_regression/main.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66

77
# Hyper-parameters
8-
input_size = 784
8+
input_size = 28 * 28 # 784
99
num_classes = 10
1010
num_epochs = 5
1111
batch_size = 100
@@ -43,7 +43,7 @@
4343
for epoch in range(num_epochs):
4444
for i, (images, labels) in enumerate(train_loader):
4545
# Reshape images to (batch_size, input_size)
46-
images = images.reshape(-1, 28*28)
46+
images = images.reshape(-1, input_size)
4747

4848
# Forward pass
4949
outputs = model(images)
@@ -64,7 +64,7 @@
6464
correct = 0
6565
total = 0
6666
for images, labels in test_loader:
67-
images = images.reshape(-1, 28*28)
67+
images = images.reshape(-1, input_size)
6868
outputs = model(images)
6969
_, predicted = torch.max(outputs.data, 1)
7070
total += labels.size(0)

tutorials/02-intermediate/deep_residual_network/main.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
# Hyper-parameters
1818
num_epochs = 80
19+
batch_size = 100
1920
learning_rate = 0.001
2021

2122
# Image preprocessing modules
@@ -37,11 +38,11 @@
3738

3839
# Data loader
3940
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
40-
batch_size=100,
41+
batch_size=batch_size,
4142
shuffle=True)
4243

4344
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
44-
batch_size=100,
45+
batch_size=batch_size,
4546
shuffle=False)
4647

4748
# 3x3 convolution

tutorials/02-intermediate/recurrent_neural_network/main.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ def forward(self, x):
8585
.format(epoch+1, num_epochs, i+1, total_step, loss.item()))
8686

8787
# Test the model
88+
model.eval()
8889
with torch.no_grad():
8990
correct = 0
9091
total = 0

0 commit comments

Comments
 (0)