Commit c255a78

ENH: classification_report format supports long string labels
1 parent 21d9ccc

File tree

2 files changed: +29 −8

    sklearn/metrics/classification.py (+8 −8)
    sklearn/metrics/tests/test_classification.py (+21 −0)

sklearn/metrics/classification.py

Lines changed: 8 additions & 8 deletions

@@ -27,7 +27,6 @@
 
 from scipy.sparse import coo_matrix
 from scipy.sparse import csr_matrix
-from scipy.spatial.distance import hamming as sp_hamming
 
 from ..preprocessing import LabelBinarizer, label_binarize
 from ..preprocessing import LabelEncoder
@@ -640,7 +639,8 @@ def f1_score(y_true, y_pred, labels=None, pos_label=1, average='binary',
 
     References
     ----------
-    .. [1] `Wikipedia entry for the F1-score <http://en.wikipedia.org/wiki/F1_score>`_
+    .. [1] `Wikipedia entry for the F1-score
+           <http://en.wikipedia.org/wiki/F1_score>`_
 
     Examples
     --------
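
For context, the reference being rewrapped points at the standard F1 definition: the harmonic mean of precision and recall. A quick illustrative check (not part of this commit; the inputs are made up):

    import numpy as np
    from sklearn.metrics import f1_score, precision_score, recall_score

    y_true = np.array([0, 1, 1, 1, 0, 1])
    y_pred = np.array([0, 1, 0, 1, 1, 1])

    p = precision_score(y_true, y_pred)  # 3 of 4 predicted positives correct: 0.75
    r = recall_score(y_true, y_pred)     # 3 of 4 true positives recovered: 0.75
    print(f1_score(y_true, y_pred))      # 2 * p * r / (p + r) = 0.75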
@@ -1386,11 +1386,9 @@ class 2 1.00 0.67 0.80 3
     last_line_heading = 'avg / total'
 
     if target_names is None:
-        width = len(last_line_heading)
         target_names = ['%s' % l for l in labels]
-    else:
-        width = max(len(cn) for cn in target_names)
-    width = max(width, len(last_line_heading), digits)
+    name_width = max(len(cn) for cn in target_names)
+    width = max(name_width, len(last_line_heading), digits)
 
     headers = ["precision", "recall", "f1-score", "support"]
     fmt = '%% %ds' % width  # first column: class name
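
This hunk is the actual fix. Previously, when target_names was None the first-column width was seeded from 'avg / total' alone, so labels stringified from `labels` never widened the column and long string labels overflowed it. Now name_width is measured over the (possibly auto-generated) target names in both paths. A minimal sketch of the new logic, reusing the names from the diff with hypothetical inputs:

    # Hypothetical labels; mirrors the width computation added above.
    labels = ["blue", "green" * 5, "red"]
    target_names = ['%s' % l for l in labels]  # the target_names-is-None path
    last_line_heading = 'avg / total'
    digits = 2

    name_width = max(len(cn) for cn in target_names)         # 25, from "green" * 5
    width = max(name_width, len(last_line_heading), digits)  # 25

    fmt = '%% %ds' % width  # first column: class name
    print(fmt % 'blue')     # 'blue' right-justified in a 25-character column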
@@ -1508,8 +1506,10 @@ def hamming_loss(y_true, y_pred, classes=None, sample_weight=None):
         weight_average = np.mean(sample_weight)
 
     if y_type.startswith('multilabel'):
-        n_differences = count_nonzero(y_true - y_pred, sample_weight=sample_weight)
-        return (n_differences / (y_true.shape[0] * len(classes) * weight_average))
+        n_differences = count_nonzero(y_true - y_pred,
+                                      sample_weight=sample_weight)
+        return (n_differences /
+                (y_true.shape[0] * len(classes) * weight_average))
 
     elif y_type in ["binary", "multiclass"]:
         return _weighted_sum(y_true != y_pred, sample_weight, normalize=True)
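
The hamming_loss hunk only rewraps long lines; the multilabel branch still computes the fraction of individual label assignments that differ. A rough unweighted sketch of that branch, assuming dense 0/1 indicator arrays, weight_average == 1, and np.count_nonzero standing in for sklearn's sparse-aware count_nonzero:

    import numpy as np

    y_true = np.array([[0, 1, 1],
                       [1, 0, 1]])
    y_pred = np.array([[0, 0, 1],
                       [1, 0, 0]])

    n_differences = np.count_nonzero(y_true - y_pred)  # 2 flipped labels
    loss = n_differences / float(y_true.shape[0] * y_true.shape[1])
    print(loss)  # 2 / (2 * 3) = 0.333..., what hamming_loss returns here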

sklearn/metrics/tests/test_classification.py

Lines changed: 21 additions & 0 deletions

@@ -669,6 +669,27 @@ def test_classification_report_multiclass_with_unicode_label():
     assert_equal(report, expected_report)
 
 
+def test_classification_report_multiclass_with_long_string_label():
+    y_true, y_pred, _ = make_prediction(binary=False)
+
+    labels = np.array(["blue", "green"*5, "red"])
+    y_true = labels[y_true]
+    y_pred = labels[y_pred]
+
+    expected_report = """\
+                           precision    recall  f1-score   support
+
+                     blue       0.83      0.79      0.81        24
+greengreengreengreengreen       0.33      0.10      0.15        31
+                      red       0.42      0.90      0.57        20
+
+              avg / total       0.51      0.53      0.47        75
+"""
+
+    report = classification_report(y_true, y_pred)
+    assert_equal(report, expected_report)
+
+
 def test_multilabel_classification_report():
     n_classes = 4
     n_samples = 50
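
The expected report pins down the new alignment: every class-name cell is right-justified to the 25-character width of 'greengreengreengreengreen'. To see the behaviour interactively (illustrative labels, not taken from the commit):

    from sklearn.metrics import classification_report

    y_true = ["blue", "green" * 5, "red", "blue"]
    y_pred = ["blue", "green" * 5, "blue", "red"]
    print(classification_report(y_true, y_pred))  # columns widen to fit the long name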
