Skip to content

Commit 2285ba0

Browse files
rossbarbsipocz
andauthored
Fix bug and improve computation / display of metrics for MNIST tutorial (#189)
* BUG: Fix incorrect variable in computing eval metrics. * ENH: Replace list comps with vectorization. * ENH: Use dicts and condense plotting. * Update content/tutorial-deep-learning-on-mnist.md Co-authored-by: Brigitta Sipőcz <[email protected]>
1 parent 73ebebf commit 2285ba0

File tree

1 file changed

+19
-27
lines changed

1 file changed

+19
-27
lines changed

content/tutorial-deep-learning-on-mnist.md

+19-27
Original file line numberDiff line numberDiff line change
@@ -561,39 +561,31 @@ The training process may take many minutes, depending on a number of factors, su
561561
After executing the cell above, you can visualize the training and test set errors and accuracy for an instance of this training process.
562562

563563
```{code-cell}
564+
epoch_range = np.arange(epochs) + 1 # Starting from 1
565+
564566
# The training set metrics.
565-
y_training_error = [
566-
store_training_loss[i] / float(len(training_images))
567-
for i in range(len(store_training_loss))
568-
]
569-
x_training_error = range(1, len(store_training_loss) + 1)
570-
y_training_accuracy = [
571-
store_training_accurate_pred[i] / float(len(training_images))
572-
for i in range(len(store_training_accurate_pred))
573-
]
574-
x_training_accuracy = range(1, len(store_training_accurate_pred) + 1)
567+
training_metrics = {
568+
"accuracy": np.asarray(store_training_accurate_pred) / len(training_images),
569+
"error": np.asarray(store_training_loss) / len(training_images),
570+
}
575571
576572
# The test set metrics.
577-
y_test_error = [
578-
store_test_loss[i] / float(len(test_images)) for i in range(len(store_test_loss))
579-
]
580-
x_test_error = range(1, len(store_test_loss) + 1)
581-
y_test_accuracy = [
582-
store_training_accurate_pred[i] / float(len(training_images))
583-
for i in range(len(store_training_accurate_pred))
584-
]
585-
x_test_accuracy = range(1, len(store_test_accurate_pred) + 1)
573+
test_metrics = {
574+
"accuracy": np.asarray(store_test_accurate_pred) / len(test_images),
575+
"error": np.asarray(store_test_loss) / len(test_images),
576+
}
586577
587578
# Display the plots.
588579
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(15, 5))
589-
axes[0].set_title("Training set error, accuracy")
590-
axes[0].plot(x_training_accuracy, y_training_accuracy, label="Training set accuracy")
591-
axes[0].plot(x_training_error, y_training_error, label="Training set error")
592-
axes[0].set_xlabel("Epochs")
593-
axes[1].set_title("Test set error, accuracy")
594-
axes[1].plot(x_test_accuracy, y_test_accuracy, label="Test set accuracy")
595-
axes[1].plot(x_test_error, y_test_error, label="Test set error")
596-
axes[1].set_xlabel("Epochs")
580+
for ax, metrics, title in zip(
581+
axes, (training_metrics, test_metrics), ("Training set", "Test set")
582+
):
583+
# Plot the metrics
584+
for metric, values in metrics.items():
585+
ax.plot(epoch_range, values, label=metric.capitalize())
586+
ax.set_title(title)
587+
ax.set_xlabel("Epochs")
588+
ax.legend()
597589
plt.show()
598590
```
599591

0 commit comments

Comments
 (0)