|
72 | 72 | discrete_test_errors.append( |
73 | 73 | 1. - accuracy_score(discrete_train_predict, y_test)) |
74 | 74 |
|
75 | | -n_trees = xrange(1, len(bdt_discrete) + 1) |
| 75 | +n_trees_discrete = len(bdt_discrete) |
| 76 | +n_trees_real = len(bdt_real) |
| 77 | + |
| 78 | +# Boosting might terminate early, but the following arrays are always |
| 79 | +# n_estimators long. We crop them to the actual number of trees here: |
| 80 | +discrete_estimator_errors = bdt_discrete.estimator_errors_[:n_trees_discrete] |
| 81 | +real_estimator_errors = bdt_real.estimator_errors_[:n_trees_real] |
| 82 | +discrete_estimator_weights = bdt_discrete.estimator_weights_[:n_trees_discrete] |
76 | 83 |
|
77 | 84 | plt.figure(figsize=(15, 5)) |
78 | 85 |
|
79 | 86 | plt.subplot(131) |
80 | | -plt.plot(n_trees, discrete_test_errors, c='black', label='SAMME') |
81 | | -plt.plot(n_trees, real_test_errors, c='black', |
82 | | - linestyle='dashed', label='SAMME.R') |
| 87 | +plt.plot(xrange(1, n_trees_discrete + 1), |
| 88 | + discrete_test_errors, c='black', label='SAMME') |
| 89 | +plt.plot(xrange(1, n_trees_real + 1), |
| 90 | + real_test_errors, c='black', |
| 91 | + linestyle='dashed', label='SAMME.R') |
83 | 92 | plt.legend() |
84 | 93 | plt.ylim(0.18, 0.62) |
85 | 94 | plt.ylabel('Test Error') |
86 | 95 | plt.xlabel('Number of Trees') |
87 | 96 |
|
88 | 97 | plt.subplot(132) |
89 | | -plt.plot(n_trees, bdt_discrete.estimator_errors_, "b", label='SAMME', alpha=.5) |
90 | | -plt.plot(n_trees, bdt_real.estimator_errors_, "r", label='SAMME.R', alpha=.5) |
| 98 | +plt.plot(xrange(1, n_trees_discrete + 1), discrete_estimator_errors, |
| 99 | + "b", label='SAMME', alpha=.5) |
| 100 | +plt.plot(xrange(1, n_trees_real + 1), real_estimator_errors, |
| 101 | + "r", label='SAMME.R', alpha=.5) |
91 | 102 | plt.legend() |
92 | 103 | plt.ylabel('Error') |
93 | 104 | plt.xlabel('Number of Trees') |
94 | 105 | plt.ylim((.2, |
95 | | - max(bdt_real.estimator_errors_.max(), |
96 | | - bdt_discrete.estimator_errors_.max()) * 1.2)) |
| 106 | + max(real_estimator_errors.max(), |
| 107 | + discrete_estimator_errors.max()) * 1.2)) |
97 | 108 | plt.xlim((-20, len(bdt_discrete) + 20)) |
98 | 109 |
|
99 | 110 | plt.subplot(133) |
100 | | -plt.plot(n_trees, bdt_discrete.estimator_weights_, "b", label='SAMME') |
| 111 | +plt.plot(xrange(1, n_trees_discrete + 1), discrete_estimator_weights, |
| 112 | + "b", label='SAMME') |
101 | 113 | plt.legend() |
102 | 114 | plt.ylabel('Weight') |
103 | 115 | plt.xlabel('Number of Trees') |
104 | | -plt.ylim((0, bdt_discrete.estimator_weights_.max() * 1.2)) |
105 | | -plt.xlim((-20, len(bdt_discrete) + 20)) |
| 116 | +plt.ylim((0, discrete_estimator_weights.max() * 1.2)) |
| 117 | +plt.xlim((-20, n_trees_discrete + 20)) |
106 | 118 |
|
107 | 119 | # prevent overlapping y-axis labels |
108 | 120 | plt.subplots_adjust(wspace=0.25) |
|
0 commit comments