
Commit 19c8af6

jnothman authored and qinhanmin2014 committed
ENH/FIX Replace jaccard_similarity_score by sane jaccard_score (scikit-learn#13151)
1 parent 35b56b1 · commit 19c8af6

File tree: 11 files changed, +529 −154 lines


doc/modules/classes.rst

Lines changed: 2 additions & 1 deletion
@@ -844,7 +844,7 @@ details.
    metrics.fbeta_score
    metrics.hamming_loss
    metrics.hinge_loss
-   metrics.jaccard_similarity_score
+   metrics.jaccard_score
    metrics.log_loss
    metrics.matthews_corrcoef
    metrics.multilabel_confusion_matrix
@@ -1505,6 +1505,7 @@ To be removed in 0.23
    utils.cpu_count
    utils.delayed
    metrics.calinski_harabaz_score
+   metrics.jaccard_similarity_score
    linear_model.logistic_regression_path

doc/modules/model_evaluation.rst

Lines changed: 59 additions & 40 deletions
@@ -70,6 +70,7 @@ Scoring Function
 'neg_log_loss'   :func:`metrics.log_loss`        requires ``predict_proba`` support
 'precision' etc. :func:`metrics.precision_score` suffixes apply as with 'f1'
 'recall' etc.    :func:`metrics.recall_score`    suffixes apply as with 'f1'
+'jaccard' etc.   :func:`metrics.jaccard_score`   suffixes apply as with 'f1'
 'roc_auc'        :func:`metrics.roc_auc_score`

 **Clustering**
@@ -326,7 +327,7 @@ Some also work in the multilabel case:
    f1_score
    fbeta_score
    hamming_loss
-   jaccard_similarity_score
+   jaccard_score
    log_loss
    multilabel_confusion_matrix
    precision_recall_fscore_support
@@ -346,6 +347,8 @@ And some work with binary and multilabel (but not multiclass) problems:
 In the following sub-sections, we will describe each of those functions,
 preceded by some notes on common API and metric definition.

+.. _average:
+
 From binary to multiclass and multilabel
 ----------------------------------------

@@ -355,8 +358,6 @@ only the positive label is evaluated, assuming by default that the positive
 class is labelled ``1`` (though this may be configurable through the
 ``pos_label`` parameter).

-.. _average:
-
 In extending a binary metric to multiclass or multilabel problems, the data
 is treated as a collection of binary problems, one for each class.
 There are then a number of ways to average binary metric calculations across
@@ -680,43 +681,6 @@ In the multilabel case with binary label indicators: ::
 or superset of the true labels will give a Hamming loss between
 zero and one, exclusive.

-.. _jaccard_similarity_score:
-
-Jaccard similarity coefficient score
--------------------------------------
-
-The :func:`jaccard_similarity_score` function computes the average (default)
-or sum of `Jaccard similarity coefficients
-<https://en.wikipedia.org/wiki/Jaccard_index>`_, also called the Jaccard index,
-between pairs of label sets.
-
-The Jaccard similarity coefficient of the :math:`i`-th samples,
-with a ground truth label set :math:`y_i` and predicted label set
-:math:`\hat{y}_i`, is defined as
-
-.. math::
-
-    J(y_i, \hat{y}_i) = \frac{|y_i \cap \hat{y}_i|}{|y_i \cup \hat{y}_i|}.
-
-In binary and multiclass classification, the Jaccard similarity coefficient
-score is equal to the classification accuracy.
-
-::
-
-  >>> import numpy as np
-  >>> from sklearn.metrics import jaccard_similarity_score
-  >>> y_pred = [0, 2, 1, 3]
-  >>> y_true = [0, 1, 2, 3]
-  >>> jaccard_similarity_score(y_true, y_pred)
-  0.5
-  >>> jaccard_similarity_score(y_true, y_pred, normalize=False)
-  2
-
-In the multilabel case with binary label indicators: ::
-
-  >>> jaccard_similarity_score(np.array([[0, 1], [1, 1]]), np.ones((2, 2)))
-  0.75
-
 .. _precision_recall_f_measure_metrics:

 Precision, recall and F-measures
@@ -957,6 +921,61 @@ Similarly, labels not present in the data sample may be accounted for in macro-a
    ... # doctest: +ELLIPSIS
    0.166...

+.. _jaccard_similarity_score:
+
+Jaccard similarity coefficient score
+-------------------------------------
+
+The :func:`jaccard_score` function computes the average of `Jaccard similarity
+coefficients <https://en.wikipedia.org/wiki/Jaccard_index>`_, also called the
+Jaccard index, between pairs of label sets.
+
+The Jaccard similarity coefficient of the :math:`i`-th samples,
+with a ground truth label set :math:`y_i` and predicted label set
+:math:`\hat{y}_i`, is defined as
+
+.. math::
+
+    J(y_i, \hat{y}_i) = \frac{|y_i \cap \hat{y}_i|}{|y_i \cup \hat{y}_i|}.
+
+:func:`jaccard_score` works like :func:`precision_recall_fscore_support` as a
+naively set-wise measure applying natively to binary targets, and extended to
+apply to multilabel and multiclass through the use of `average` (see
+:ref:`above <average>`).
+
+In the binary case: ::
+
+  >>> import numpy as np
+  >>> from sklearn.metrics import jaccard_score
+  >>> y_true = np.array([[0, 1, 1],
+  ...                    [1, 1, 0]])
+  >>> y_pred = np.array([[1, 1, 1],
+  ...                    [1, 0, 0]])
+  >>> jaccard_score(y_true[0], y_pred[0])  # doctest: +ELLIPSIS
+  0.6666...
+
+In the multilabel case with binary label indicators: ::
+
+  >>> jaccard_score(y_true, y_pred, average='samples')  # doctest: +ELLIPSIS
+  0.5833...
+  >>> jaccard_score(y_true, y_pred, average='macro')  # doctest: +ELLIPSIS
+  0.6666...
+  >>> jaccard_score(y_true, y_pred, average=None)
+  array([0.5, 0.5, 1. ])
+
+Multiclass problems are binarized and treated like the corresponding
+multilabel problem: ::
+
+  >>> y_pred = [0, 2, 1, 2]
+  >>> y_true = [0, 1, 2, 2]
+  >>> jaccard_score(y_true, y_pred, average=None)
+  ... # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS
+  array([1. , 0. , 0.33...])
+  >>> jaccard_score(y_true, y_pred, average='macro')
+  0.44...
+  >>> jaccard_score(y_true, y_pred, average='micro')
+  0.33...
+
 .. _hinge_loss:

 Hinge loss
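
Usage note: the new 'jaccard' scorer row above states that its suffixes apply as with 'f1'. Below is a minimal sketch of scorer-string usage; the dataset and estimator are illustrative assumptions, not part of this commit: ::

  >>> from sklearn.datasets import make_classification
  >>> from sklearn.linear_model import LogisticRegression
  >>> from sklearn.model_selection import cross_val_score
  >>> X, y = make_classification(random_state=0)  # synthetic binary task
  >>> clf = LogisticRegression(solver='lbfgs')
  >>> # Per the table, suffixed variants such as 'jaccard_macro' or
  >>> # 'jaccard_micro' should select the corresponding `average` mode,
  >>> # exactly as the 'f1' suffixes do.
  >>> scores = cross_val_score(clf, X, y, scoring='jaccard', cv=3)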

doc/whats_new/v0.21.rst

Lines changed: 10 additions & 0 deletions
@@ -293,6 +293,11 @@ Support for Python 3.4 and below has been officially dropped.
   metrics such as recall, specificity, fall out and miss rate.
   :issue:`11179` by :user:`Shangwu Yao <ShangwuYao>` and `Joel Nothman`_.

+- |Feature| :func:`metrics.jaccard_score` has been added to calculate the
+  Jaccard coefficient as an evaluation metric for binary, multilabel and
+  multiclass tasks, with an interface analogous to :func:`metrics.f1_score`.
+  :issue:`13151` by :user:`Gaurav Dhingra <gxyd>` and `Joel Nothman`_.
+
 - |Efficiency| Faster :func:`metrics.pairwise.pairwise_distances` with `n_jobs`
   > 1 by using a thread-based backend, instead of process-based backends.
   :issue:`8216` by :user:`Pierre Glaser <pierreglaser>` and
@@ -318,6 +323,11 @@ Support for Python 3.4 and below has been officially dropped.
   :issue:`10580` by :user:`Reshama Shaikh <reshamas>` and `Sandra
   Mitrovic <SandraMNE>`.

+- |API| :func:`metrics.jaccard_similarity_score` is deprecated in favour of
+  the more consistent :func:`metrics.jaccard_score`. The former behavior for
+  binary and multiclass targets is broken.
+  :issue:`13151` by `Joel Nothman`_.
+
 :mod:`sklearn.mixture`
 ......................
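
To make the |API| note concrete: on binary and multiclass targets the old function reduced to plain accuracy (the removed documentation said as much), whereas :func:`metrics.jaccard_score` is genuinely set-wise. A sketch reusing the multiclass example from the updated docs: ::

  >>> from sklearn.metrics import jaccard_score, jaccard_similarity_score
  >>> y_pred = [0, 2, 1, 2]
  >>> y_true = [0, 1, 2, 2]
  >>> jaccard_similarity_score(y_true, y_pred)  # deprecated; equals accuracy here
  0.5
  >>> jaccard_score(y_true, y_pred, average='micro')  # doctest: +ELLIPSIS
  0.33...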

examples/multioutput/plot_classifier_chain_yeast.py

Lines changed: 8 additions & 6 deletions
@@ -10,7 +10,7 @@
 data point has at least one label. As a baseline we first train a logistic
 regression classifier for each of the 14 labels. To evaluate the performance of
 these classifiers we predict on a held-out test set and calculate the
-:ref:`jaccard similarity score <jaccard_similarity_score>`.
+:ref:`jaccard score <jaccard_score>` for each sample.

 Next we create 10 classifier chains. Each classifier chain contains a
 logistic regression model for each of the 14 labels. The models in each
@@ -41,7 +41,7 @@
 from sklearn.multioutput import ClassifierChain
 from sklearn.model_selection import train_test_split
 from sklearn.multiclass import OneVsRestClassifier
-from sklearn.metrics import jaccard_similarity_score
+from sklearn.metrics import jaccard_score
 from sklearn.linear_model import LogisticRegression

 print(__doc__)
@@ -58,7 +58,7 @@
 ovr = OneVsRestClassifier(base_lr)
 ovr.fit(X_train, Y_train)
 Y_pred_ovr = ovr.predict(X_test)
-ovr_jaccard_score = jaccard_similarity_score(Y_test, Y_pred_ovr)
+ovr_jaccard_score = jaccard_score(Y_test, Y_pred_ovr, average='samples')

 # Fit an ensemble of logistic regression classifier chains and take the
 # take the average prediction of all the chains.
@@ -69,12 +69,14 @@

 Y_pred_chains = np.array([chain.predict(X_test) for chain in
                           chains])
-chain_jaccard_scores = [jaccard_similarity_score(Y_test, Y_pred_chain >= .5)
+chain_jaccard_scores = [jaccard_score(Y_test, Y_pred_chain >= .5,
+                                      average='samples')
                         for Y_pred_chain in Y_pred_chains]

 Y_pred_ensemble = Y_pred_chains.mean(axis=0)
-ensemble_jaccard_score = jaccard_similarity_score(Y_test,
-                                                  Y_pred_ensemble >= .5)
+ensemble_jaccard_score = jaccard_score(Y_test,
+                                       Y_pred_ensemble >= .5,
+                                       average='samples')

 model_scores = [ovr_jaccard_score] + chain_jaccard_scores
 model_scores.append(ensemble_jaccard_score)
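
Why ``average='samples'`` in each replacement above: for multilabel indicator matrices it averages the per-sample Jaccard coefficients, which reproduces what ``jaccard_similarity_score`` used to report for multilabel input. A small sketch using the data from the removed documentation example: ::

  >>> import numpy as np
  >>> from sklearn.metrics import jaccard_score
  >>> Y_true = np.array([[0, 1], [1, 1]])
  >>> jaccard_score(Y_true, np.ones((2, 2)), average='samples')
  0.75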

sklearn/metrics/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -23,6 +23,7 @@
 from .classification import hamming_loss
 from .classification import hinge_loss
 from .classification import jaccard_similarity_score
+from .classification import jaccard_score
 from .classification import log_loss
 from .classification import matthews_corrcoef
 from .classification import precision_recall_fscore_support
@@ -98,6 +99,7 @@
     'hinge_loss',
     'homogeneity_completeness_v_measure',
     'homogeneity_score',
+    'jaccard_score',
     'jaccard_similarity_score',
     'label_ranking_average_precision_score',
     'label_ranking_loss',
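
Both names remain importable after this change, with the old one deprecated per the what's new entry above. A quick sanity check against the binary example from the updated docs: ::

  >>> from sklearn.metrics import jaccard_score, jaccard_similarity_score
  >>> jaccard_score([0, 1, 1], [1, 1, 1])  # doctest: +ELLIPSIS
  0.6666...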
