Skip to content

Commit 65bbe9e

Browse files
committed
plot_adaboost_multiclass.py: handle the case where boosting terminates early. Add the missing author header to the other boosting examples.
1 parent a32eb88 commit 65bbe9e

File tree

3 files changed

+38
-18
lines changed

3 files changed

+38
-18
lines changed

examples/ensemble/plot_adaboost_multiclass.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -72,37 +72,49 @@
7272
discrete_test_errors.append(
7373
1. - accuracy_score(discrete_train_predict, y_test))
7474

75-
n_trees = xrange(1, len(bdt_discrete) + 1)
75+
n_trees_discrete = len(bdt_discrete)
76+
n_trees_real = len(bdt_real)
77+
78+
# Boosting might terminate early, but the following arrays are always
79+
# n_estimators long. We crop them to the actual number of trees here:
80+
discrete_estimator_errors = bdt_discrete.estimator_errors_[:n_trees_discrete]
81+
real_estimator_errors = bdt_real.estimator_errors_[:n_trees_real]
82+
discrete_estimator_weights = bdt_discrete.estimator_weights_[:n_trees_discrete]
7683

7784
plt.figure(figsize=(15, 5))
7885

7986
plt.subplot(131)
80-
plt.plot(n_trees, discrete_test_errors, c='black', label='SAMME')
81-
plt.plot(n_trees, real_test_errors, c='black',
82-
linestyle='dashed', label='SAMME.R')
87+
plt.plot(xrange(1, n_trees_discrete + 1),
88+
discrete_test_errors, c='black', label='SAMME')
89+
plt.plot(xrange(1, n_trees_real + 1),
90+
real_test_errors, c='black',
91+
linestyle='dashed', label='SAMME.R')
8392
plt.legend()
8493
plt.ylim(0.18, 0.62)
8594
plt.ylabel('Test Error')
8695
plt.xlabel('Number of Trees')
8796

8897
plt.subplot(132)
89-
plt.plot(n_trees, bdt_discrete.estimator_errors_, "b", label='SAMME', alpha=.5)
90-
plt.plot(n_trees, bdt_real.estimator_errors_, "r", label='SAMME.R', alpha=.5)
98+
plt.plot(xrange(1, n_trees_discrete + 1), discrete_estimator_errors,
99+
"b", label='SAMME', alpha=.5)
100+
plt.plot(xrange(1, n_trees_real + 1), real_estimator_errors,
101+
"r", label='SAMME.R', alpha=.5)
91102
plt.legend()
92103
plt.ylabel('Error')
93104
plt.xlabel('Number of Trees')
94105
plt.ylim((.2,
95-
max(bdt_real.estimator_errors_.max(),
96-
bdt_discrete.estimator_errors_.max()) * 1.2))
106+
max(real_estimator_errors.max(),
107+
discrete_estimator_errors.max()) * 1.2))
97108
plt.xlim((-20, len(bdt_discrete) + 20))
98109

99110
plt.subplot(133)
100-
plt.plot(n_trees, bdt_discrete.estimator_weights_, "b", label='SAMME')
111+
plt.plot(xrange(1, n_trees_discrete + 1), discrete_estimator_weights,
112+
"b", label='SAMME')
101113
plt.legend()
102114
plt.ylabel('Weight')
103115
plt.xlabel('Number of Trees')
104-
plt.ylim((0, bdt_discrete.estimator_weights_.max() * 1.2))
105-
plt.xlim((-20, len(bdt_discrete) + 20))
116+
plt.ylim((0, discrete_estimator_weights.max() * 1.2))
117+
plt.xlim((-20, n_trees_discrete + 20))
106118

107119
# prevent overlapping y-axis labels
108120
plt.subplots_adjust(wspace=0.25)

examples/ensemble/plot_adaboost_regression.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@
1414
"""
1515
print(__doc__)
1616

17+
# Author: Noel Dawe <[email protected]>
18+
#
19+
# License: BSD 3 clause
20+
1721
import numpy as np
1822
import matplotlib.pyplot as plt
1923

examples/ensemble/plot_adaboost_twoclass.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@
1818
"""
1919
print(__doc__)
2020

21+
# Author: Noel Dawe <[email protected]>
22+
#
23+
# License: BSD 3 clause
24+
2125
import numpy as np
2226
import matplotlib.pyplot as plt
2327

@@ -65,8 +69,8 @@
6569
for i, n, c in zip(range(2), class_names, plot_colors):
6670
idx = np.where(y == i)
6771
plt.scatter(X[idx, 0], X[idx, 1],
68-
c=c, cmap=plt.cm.Paired,
69-
label="Class %s" % n)
72+
c=c, cmap=plt.cm.Paired,
73+
label="Class %s" % n)
7074
plt.xlim(x_min, x_max)
7175
plt.ylim(y_min, y_max)
7276
plt.legend(loc='upper right')
@@ -78,11 +82,11 @@
7882
plt.subplot(122)
7983
for i, n, c in zip(range(2), class_names, plot_colors):
8084
plt.hist(twoclass_output[y == i],
81-
bins=10,
82-
range=plot_range,
83-
facecolor=c,
84-
label='Class %s' % n,
85-
alpha=.5)
85+
bins=10,
86+
range=plot_range,
87+
facecolor=c,
88+
label='Class %s' % n,
89+
alpha=.5)
8690
x1, x2, y1, y2 = plt.axis()
8791
plt.axis((x1, x2, y1, y2 * 1.2))
8892
plt.legend(loc='upper right')

0 commit comments

Comments
 (0)