We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 9927173 commit 3008d88Copy full SHA for 3008d88
tutorials/01-basics/logistic_regression/main.py
@@ -5,7 +5,7 @@
5
6
7
# Hyper-parameters
8
-input_size = 784
+input_size = 28 * 28 # 784
9
num_classes = 10
10
num_epochs = 5
11
batch_size = 100
@@ -43,7 +43,7 @@
43
for epoch in range(num_epochs):
44
for i, (images, labels) in enumerate(train_loader):
45
# Reshape images to (batch_size, input_size)
46
- images = images.reshape(-1, 28*28)
+ images = images.reshape(-1, input_size)
47
48
# Forward pass
49
outputs = model(images)
@@ -64,7 +64,7 @@
64
correct = 0
65
total = 0
66
for images, labels in test_loader:
67
68
69
_, predicted = torch.max(outputs.data, 1)
70
total += labels.size(0)
0 commit comments