@@ -561,39 +561,31 @@ The training process may take many minutes, depending on a number of factors, su
561
561
After executing the cell above, you can visualize the training and test set errors and accuracy for an instance of this training process.
562
562
563
563
``` {code-cell}
564
+ epoch_range = np.arange(epochs) + 1 # Starting from 1
565
+
564
566
# The training set metrics.
565
- y_training_error = [
566
- store_training_loss[i] / float(len(training_images))
567
- for i in range(len(store_training_loss))
568
- ]
569
- x_training_error = range(1, len(store_training_loss) + 1)
570
- y_training_accuracy = [
571
- store_training_accurate_pred[i] / float(len(training_images))
572
- for i in range(len(store_training_accurate_pred))
573
- ]
574
- x_training_accuracy = range(1, len(store_training_accurate_pred) + 1)
567
+ training_metrics = {
568
+ "accuracy": np.asarray(store_training_accurate_pred) / len(training_images),
569
+ "error": np.asarray(store_training_loss) / len(training_images),
570
+ }
575
571
576
572
# The test set metrics.
577
- y_test_error = [
578
- store_test_loss[i] / float(len(test_images)) for i in range(len(store_test_loss))
579
- ]
580
- x_test_error = range(1, len(store_test_loss) + 1)
581
- y_test_accuracy = [
582
- store_training_accurate_pred[i] / float(len(training_images))
583
- for i in range(len(store_training_accurate_pred))
584
- ]
585
- x_test_accuracy = range(1, len(store_test_accurate_pred) + 1)
573
+ test_metrics = {
574
+ "accuracy": np.asarray(store_test_accurate_pred) / len(test_images),
575
+ "error": np.asarray(store_test_loss) / len(test_images),
576
+ }
586
577
587
578
# Display the plots.
588
579
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(15, 5))
589
- axes[0].set_title("Training set error, accuracy")
590
- axes[0].plot(x_training_accuracy, y_training_accuracy, label="Training set accuracy")
591
- axes[0].plot(x_training_error, y_training_error, label="Training set error")
592
- axes[0].set_xlabel("Epochs")
593
- axes[1].set_title("Test set error, accuracy")
594
- axes[1].plot(x_test_accuracy, y_test_accuracy, label="Test set accuracy")
595
- axes[1].plot(x_test_error, y_test_error, label="Test set error")
596
- axes[1].set_xlabel("Epochs")
580
+ for ax, metrics, title in zip(
581
+ axes, (training_metrics, test_metrics), ("Training set", "Test set")
582
+ ):
583
+ # Plot the metrics
584
+ for metric, values in metrics.items():
585
+ ax.plot(epoch_range, values, label=metric.capitalize())
586
+ ax.set_title(title)
587
+ ax.set_xlabel("Epochs")
588
+ ax.legend()
597
589
plt.show()
598
590
```
599
591
0 commit comments