Skip to content

Commit 3008d88

Browse files
committed
minor refactoring for input size
1 parent 9927173 commit 3008d88

File tree

1 file changed

+3
-3
lines changed
  • tutorials/01-basics/logistic_regression

1 file changed

+3
-3
lines changed

tutorials/01-basics/logistic_regression/main.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66

77
# Hyper-parameters
8-
input_size = 784
8+
input_size = 28 * 28 # 784
99
num_classes = 10
1010
num_epochs = 5
1111
batch_size = 100
@@ -43,7 +43,7 @@
4343
for epoch in range(num_epochs):
4444
for i, (images, labels) in enumerate(train_loader):
4545
# Reshape images to (batch_size, input_size)
46-
images = images.reshape(-1, 28*28)
46+
images = images.reshape(-1, input_size)
4747

4848
# Forward pass
4949
outputs = model(images)
@@ -64,7 +64,7 @@
6464
correct = 0
6565
total = 0
6666
for images, labels in test_loader:
67-
images = images.reshape(-1, 28*28)
67+
images = images.reshape(-1, input_size)
6868
outputs = model(images)
6969
_, predicted = torch.max(outputs.data, 1)
7070
total += labels.size(0)

0 commit comments

Comments
 (0)