Skip to content

Commit 511bbc7

Browse files
Adam KleczewskiTomDLT
authored andcommitted
[MRG+1] Chassifier chain example fix (scikit-learn#9408)
1 parent 11e7369 commit 511bbc7

File tree

1 file changed

+18
-17
lines changed

1 file changed

+18
-17
lines changed

examples/multioutput/plot_classifier_chain_yeast.py

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@
55
Example of using classifier chain on a multilabel dataset.
66
77
For this example we will use the `yeast
8-
<http://mldata.org/repository/data/viewslug/yeast>`_ dataset which
9-
contains 2417 datapoints each with 103 features and 14 possible labels. Each
10-
datapoint has at least one label. As a baseline we first train a logistic
11-
regression classifier for each of the 14 labels. To evaluate the performance
12-
of these classifiers we predict on a held-out test set and calculate the
13-
:ref:`User Guide <jaccard_similarity_score>`.
8+
<http://mldata.org/repository/data/viewslug/yeast>`_ dataset which contains
9+
2417 datapoints each with 103 features and 14 possible labels. Each
10+
data point has at least one label. As a baseline we first train a logistic
11+
regression classifier for each of the 14 labels. To evaluate the performance of
12+
these classifiers we predict on a held-out test set and calculate the
13+
:ref:`jaccard similarity score <jaccard_similarity_score>`.
1414
1515
Next we create 10 classifier chains. Each classifier chain contains a
1616
logistic regression model for each of the 14 labels. The models in each
@@ -79,7 +79,7 @@
7979
model_scores = [ovr_jaccard_score] + chain_jaccard_scores
8080
model_scores.append(ensemble_jaccard_score)
8181

82-
model_names = ('Independent Models',
82+
model_names = ('Independent',
8383
'Chain 1',
8484
'Chain 2',
8585
'Chain 3',
@@ -90,21 +90,22 @@
9090
'Chain 8',
9191
'Chain 9',
9292
'Chain 10',
93-
'Ensemble Average')
93+
'Ensemble')
9494

95-
y_pos = np.arange(len(model_names))
96-
y_pos[1:] += 1
97-
y_pos[-1] += 1
95+
x_pos = np.arange(len(model_names))
9896

9997
# Plot the Jaccard similarity scores for the independent model, each of the
10098
# chains, and the ensemble (note that the vertical axis on this plot does
10199
# not begin at 0).
102100

103-
fig = plt.figure(figsize=(7, 4))
104-
plt.title('Classifier Chain Ensemble')
105-
plt.xticks(y_pos, model_names, rotation='vertical')
106-
plt.ylabel('Jaccard Similarity Score')
107-
plt.ylim([min(model_scores) * .9, max(model_scores) * 1.1])
101+
fig, ax = plt.subplots(figsize=(7, 4))
102+
ax.grid(True)
103+
ax.set_title('Classifier Chain Ensemble Performance Comparison')
104+
ax.set_xticks(x_pos)
105+
ax.set_xticklabels(model_names, rotation='vertical')
106+
ax.set_ylabel('Jaccard Similarity Score')
107+
ax.set_ylim([min(model_scores) * .9, max(model_scores) * 1.1])
108108
colors = ['r'] + ['b'] * len(chain_jaccard_scores) + ['g']
109-
plt.bar(y_pos, model_scores, align='center', alpha=0.5, color=colors)
109+
ax.bar(x_pos, model_scores, alpha=0.5, color=colors)
110+
plt.tight_layout()
110111
plt.show()

0 commit comments

Comments
 (0)