Skip to content

Commit f0cce3a

Browse files
authored
FIX Infer pos_label in Display method from_cv_results (scikit-learn#32372)
1 parent 8d4fc4a commit f0cce3a

File tree

5 files changed

+23
-35
lines changed

5 files changed

+23
-35
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
- :meth:`metrics.RocCurveDisplay.from_cv_results` will now infer `pos_label` as
2+
`estimator.classes_[-1]`, using the estimator from `cv_results`, when
3+
`pos_label=None`. Previously, an error was raised when `pos_label=None`.
4+
By :user:`Lucy Liu <lucyleeow>`.

sklearn/metrics/_plot/roc_curve.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -665,8 +665,8 @@ def from_cv_results(
665665
666666
pos_label : int, float, bool or str, default=None
667667
The class considered as the positive class when computing the ROC AUC
668-
metrics. By default, `estimators.classes_[1]` is considered
669-
as the positive class.
668+
metrics. By default, `estimator.classes_[1]` (using `estimator` from
669+
`cv_results`) is considered as the positive class.
670670
671671
ax : matplotlib axes, default=None
672672
Axes object to plot on. If `None`, a new figure and axes is
@@ -730,24 +730,23 @@ def from_cv_results(
730730
<...>
731731
>>> plt.show()
732732
"""
733-
pos_label_ = cls._validate_from_cv_results_params(
733+
cls._validate_from_cv_results_params(
734734
cv_results,
735735
X,
736736
y,
737737
sample_weight=sample_weight,
738-
pos_label=pos_label,
739738
)
740739

741740
fpr_folds, tpr_folds, auc_folds = [], [], []
742741
for estimator, test_indices in zip(
743742
cv_results["estimator"], cv_results["indices"]["test"]
744743
):
745744
y_true = _safe_indexing(y, test_indices)
746-
y_pred, _ = _get_response_values_binary(
745+
y_pred, pos_label_ = _get_response_values_binary(
747746
estimator,
748747
_safe_indexing(X, test_indices),
749748
response_method=response_method,
750-
pos_label=pos_label_,
749+
pos_label=pos_label,
751750
)
752751
sample_weight_fold = (
753752
None

sklearn/metrics/_plot/tests/test_roc_curve_display.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from sklearn import clone
99
from sklearn.compose import make_column_transformer
1010
from sklearn.datasets import load_breast_cancer, make_classification
11-
from sklearn.exceptions import NotFittedError
11+
from sklearn.exceptions import NotFittedError, UndefinedMetricWarning
1212
from sklearn.linear_model import LogisticRegression
1313
from sklearn.metrics import RocCurveDisplay, auc, roc_curve
1414
from sklearn.model_selection import cross_validate, train_test_split
@@ -264,7 +264,7 @@ def test_roc_curve_from_cv_results_param_validation(pyplot, data_binary):
264264

265265
# `pos_label` inconsistency
266266
y_multi[y_multi == 1] = 2
267-
with pytest.raises(ValueError, match=r"y takes value in \{0, 2\}"):
267+
with pytest.warns(UndefinedMetricWarning, match="No positive samples in y_true"):
268268
RocCurveDisplay.from_cv_results(cv_results, X, y_multi)
269269

270270
# `name` is list while `curve_kwargs` is None or dict
@@ -588,6 +588,18 @@ def test_roc_curve_from_cv_results_curve_kwargs(pyplot, data_binary, curve_kwarg
588588
assert color == curve_kwargs[idx]["c"]
589589

590590

591+
def test_roc_curve_from_cv_results_pos_label_inferred(pyplot, data_binary):
592+
"""Check `pos_label` inferred correctly by `from_cv_results(pos_label=None)`."""
593+
X, y = data_binary
594+
cv_results = cross_validate(
595+
LogisticRegression(), X, y, cv=3, return_estimator=True, return_indices=True
596+
)
597+
598+
disp = RocCurveDisplay.from_cv_results(cv_results, X, y, pos_label=None)
599+
# Should be `estimator.classes_[1]`
600+
assert disp.pos_label == 1
601+
602+
591603
def _check_chance_level(plot_chance_level, chance_level_kw, display):
592604
"""Check chance level line and line styles correct."""
593605
import matplotlib as mpl

sklearn/utils/_plotting.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,6 @@ def _validate_from_cv_results_params(
7777
y,
7878
*,
7979
sample_weight,
80-
pos_label,
8180
):
8281
check_matplotlib_support(f"{cls.__name__}.from_cv_results")
8382

@@ -107,14 +106,6 @@ def _validate_from_cv_results_params(
107106
)
108107
check_consistent_length(X, y, sample_weight)
109108

110-
try:
111-
pos_label = _check_pos_label_consistency(pos_label, y)
112-
except ValueError as e:
113-
# Adapt error message
114-
raise ValueError(str(e).replace("y_true", "y"))
115-
116-
return pos_label
117-
118109
@staticmethod
119110
def _get_legend_label(curve_legend_metric, curve_name, legend_metric_name):
120111
"""Helper to get legend label using `name` and `legend_metric`"""

sklearn/utils/tests/test_plotting.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,6 @@ def test_validate_from_predictions_params_returns(pyplot, name, pos_label, y_tru
128128
"X": np.array([[1, 2], [3, 4]]),
129129
"y": np.array([0, 1]),
130130
"sample_weight": None,
131-
"pos_label": None,
132131
},
133132
"`cv_results` does not contain one of the following",
134133
),
@@ -142,7 +141,6 @@ def test_validate_from_predictions_params_returns(pyplot, name, pos_label, y_tru
142141
"X": np.array([[1, 2]]),
143142
"y": np.array([0, 1]),
144143
"sample_weight": None,
145-
"pos_label": None,
146144
},
147145
"`X` does not contain the correct number of",
148146
),
@@ -156,7 +154,6 @@ def test_validate_from_predictions_params_returns(pyplot, name, pos_label, y_tru
156154
# `y` not binary
157155
"y": np.array([0, 2, 1, 3]),
158156
"sample_weight": None,
159-
"pos_label": None,
160157
},
161158
"The target `y` is not binary",
162159
),
@@ -170,24 +167,9 @@ def test_validate_from_predictions_params_returns(pyplot, name, pos_label, y_tru
170167
"y": np.array([0, 1, 0, 1]),
171168
# `sample_weight` wrong length
172169
"sample_weight": np.array([0.5]),
173-
"pos_label": None,
174170
},
175171
"Found input variables with inconsistent",
176172
),
177-
(
178-
{
179-
"cv_results": {
180-
"estimator": "dummy",
181-
"indices": {"test": [[1, 2], [1, 2]], "train": [[3, 4], [3, 4]]},
182-
},
183-
"X": np.array([1, 2, 3, 4]),
184-
"y": np.array([2, 3, 2, 3]),
185-
"sample_weight": None,
186-
# Not specified when `y` not in {0, 1} or {-1, 1}
187-
"pos_label": None,
188-
},
189-
"y takes value in {2, 3} and pos_label is not specified",
190-
),
191173
],
192174
)
193175
def test_validate_from_cv_results_params(pyplot, params, err_msg):

0 commit comments

Comments
 (0)