Skip to content

Commit 70f9e9f

Browse files
NicolasHug authored and adrinjalali committed
MNT deprecate some more utils in estimator_checks.py (scikit-learn#15029)
* deprecated notanarray
* deprecated is_public_parameter
* deprecated pairwise_estimator_convert_X
* deprecated checking_parameters
* pep8
* use decorator for class
1 parent e6a4dc9 commit 70f9e9f

File tree

6 files changed

+97
-58
lines changed

6 files changed

+97
-58
lines changed

sklearn/tests/test_common.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
from sklearn.utils.testing import SkipTest
3333
from sklearn.utils.estimator_checks import (
3434
_construct_instance,
35-
set_checking_parameters,
35+
_set_checking_parameters,
3636
_set_check_estimator_ids,
3737
check_parameters_default_constructible,
3838
check_class_weight_balanced_linear_classifier,
@@ -93,7 +93,7 @@ def test_estimators(estimator, check):
9393
# Common tests for estimator instances
9494
with ignore_warnings(category=(DeprecationWarning, ConvergenceWarning,
9595
UserWarning, FutureWarning)):
96-
set_checking_parameters(estimator)
96+
_set_checking_parameters(estimator)
9797
check(estimator)
9898

9999

sklearn/utils/estimator_checks.py

Lines changed: 60 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -438,7 +438,13 @@ def _boston_subset(n_samples=200):
438438
return BOSTON
439439

440440

441+
@deprecated("set_checking_parameters is deprecated in version "
442+
"0.22 and will be removed in version 0.24.")
441443
def set_checking_parameters(estimator):
444+
_set_checking_parameters(estimator)
445+
446+
447+
def _set_checking_parameters(estimator):
442448
# set parameters to speed up some estimators and
443449
# avoid deprecated behaviour
444450
params = estimator.get_params()
@@ -519,7 +525,7 @@ def set_checking_parameters(estimator):
519525
estimator.set_params(handle_unknown='ignore')
520526

521527

522-
class NotAnArray:
528+
class _NotAnArray:
523529
"""An object that is convertible to an array
524530
525531
Parameters
@@ -535,6 +541,13 @@ def __array__(self, dtype=None):
535541
return self.data
536542

537543

544+
@deprecated("NotAnArray is deprecated in version "
545+
"0.22 and will be removed in version 0.24.")
546+
class NotAnArray(_NotAnArray):
547+
# TODO: remove in 0.24
548+
pass
549+
550+
538551
def _is_pairwise(estimator):
539552
"""Returns True if estimator has a _pairwise attribute set to True.
540553
@@ -569,7 +582,13 @@ def _is_pairwise_metric(estimator):
569582
return bool(metric == 'precomputed')
570583

571584

585+
@deprecated("pairwise_estimator_convert_X is deprecated in version "
586+
"0.22 and will be removed in version 0.24.")
572587
def pairwise_estimator_convert_X(X, estimator, kernel=linear_kernel):
588+
return _pairwise_estimator_convert_X(X, estimator, kernel)
589+
590+
591+
def _pairwise_estimator_convert_X(X, estimator, kernel=linear_kernel):
573592

574593
if _is_pairwise_metric(estimator):
575594
return pairwise_distances(X, metric='euclidean')
@@ -616,7 +635,7 @@ def check_estimator_sparse_data(name, estimator_orig):
616635
rng = np.random.RandomState(0)
617636
X = rng.rand(40, 10)
618637
X[X < .8] = 0
619-
X = pairwise_estimator_convert_X(X, estimator_orig)
638+
X = _pairwise_estimator_convert_X(X, estimator_orig)
620639
X_csr = sparse.csr_matrix(X)
621640
tags = _safe_tags(estimator_orig)
622641
if tags['binary_only']:
@@ -681,7 +700,7 @@ def check_sample_weights_pandas_series(name, estimator_orig):
681700
X = np.array([[1, 1], [1, 2], [1, 3], [1, 4],
682701
[2, 1], [2, 2], [2, 3], [2, 4],
683702
[3, 1], [3, 2], [3, 3], [3, 4]])
684-
X = pd.DataFrame(pairwise_estimator_convert_X(X, estimator_orig))
703+
X = pd.DataFrame(_pairwise_estimator_convert_X(X, estimator_orig))
685704
y = pd.Series([1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 2, 2])
686705
weights = pd.Series([1] * 12)
687706
if _safe_tags(estimator, "multioutput_only"):
@@ -705,7 +724,7 @@ def check_sample_weights_list(name, estimator_orig):
705724
estimator = clone(estimator_orig)
706725
rnd = np.random.RandomState(0)
707726
n_samples = 30
708-
X = pairwise_estimator_convert_X(rnd.uniform(size=(n_samples, 3)),
727+
X = _pairwise_estimator_convert_X(rnd.uniform(size=(n_samples, 3)),
709728
estimator_orig)
710729
if _safe_tags(estimator, 'binary_only'):
711730
y = np.arange(n_samples) % 2
@@ -759,7 +778,7 @@ def check_sample_weights_invariance(name, estimator_orig):
759778
def check_dtype_object(name, estimator_orig):
760779
# check that estimators treat dtype object as numeric if possible
761780
rng = np.random.RandomState(0)
762-
X = pairwise_estimator_convert_X(rng.rand(40, 10), estimator_orig)
781+
X = _pairwise_estimator_convert_X(rng.rand(40, 10), estimator_orig)
763782
X = X.astype(object)
764783
tags = _safe_tags(estimator_orig)
765784
if tags['binary_only']:
@@ -818,7 +837,7 @@ def check_dict_unchanged(name, estimator_orig):
818837
else:
819838
X = 2 * rnd.uniform(size=(20, 3))
820839

821-
X = pairwise_estimator_convert_X(X, estimator_orig)
840+
X = _pairwise_estimator_convert_X(X, estimator_orig)
822841

823842
y = X[:, 0].astype(np.int)
824843
estimator = clone(estimator_orig)
@@ -844,7 +863,13 @@ def check_dict_unchanged(name, estimator_orig):
844863
'Estimator changes __dict__ during %s' % method)
845864

846865

866+
@deprecated("is_public_parameter is deprecated in version "
867+
"0.22 and will be removed in version 0.24.")
847868
def is_public_parameter(attr):
869+
return _is_public_parameter(attr)
870+
871+
872+
def _is_public_parameter(attr):
848873
return not (attr.startswith('_') or attr.endswith('_'))
849874

850875

@@ -857,7 +882,7 @@ def check_dont_overwrite_parameters(name, estimator_orig):
857882
estimator = clone(estimator_orig)
858883
rnd = np.random.RandomState(0)
859884
X = 3 * rnd.uniform(size=(20, 3))
860-
X = pairwise_estimator_convert_X(X, estimator_orig)
885+
X = _pairwise_estimator_convert_X(X, estimator_orig)
861886
y = X[:, 0].astype(np.int)
862887
if _safe_tags(estimator, 'binary_only'):
863888
y[y == 2] = 1
@@ -875,7 +900,7 @@ def check_dont_overwrite_parameters(name, estimator_orig):
875900
dict_after_fit = estimator.__dict__
876901

877902
public_keys_after_fit = [key for key in dict_after_fit.keys()
878-
if is_public_parameter(key)]
903+
if _is_public_parameter(key)]
879904

880905
attrs_added_by_fit = [key for key in public_keys_after_fit
881906
if key not in dict_before_fit.keys()]
@@ -908,7 +933,7 @@ def check_fit2d_predict1d(name, estimator_orig):
908933
# check by fitting a 2d array and predicting with a 1d array
909934
rnd = np.random.RandomState(0)
910935
X = 3 * rnd.uniform(size=(20, 3))
911-
X = pairwise_estimator_convert_X(X, estimator_orig)
936+
X = _pairwise_estimator_convert_X(X, estimator_orig)
912937
y = X[:, 0].astype(np.int)
913938
tags = _safe_tags(estimator_orig)
914939
if tags['binary_only']:
@@ -959,7 +984,7 @@ def check_methods_subset_invariance(name, estimator_orig):
959984
# on mini batches or the whole set
960985
rnd = np.random.RandomState(0)
961986
X = 3 * rnd.uniform(size=(20, 3))
962-
X = pairwise_estimator_convert_X(X, estimator_orig)
987+
X = _pairwise_estimator_convert_X(X, estimator_orig)
963988
y = X[:, 0].astype(np.int)
964989
if _safe_tags(estimator_orig, 'binary_only'):
965990
y[y == 2] = 1
@@ -1001,7 +1026,7 @@ def check_fit2d_1sample(name, estimator_orig):
10011026
# the number of samples or the number of classes.
10021027
rnd = np.random.RandomState(0)
10031028
X = 3 * rnd.uniform(size=(1, 10))
1004-
X = pairwise_estimator_convert_X(X, estimator_orig)
1029+
X = _pairwise_estimator_convert_X(X, estimator_orig)
10051030

10061031
y = X[:, 0].astype(np.int)
10071032
estimator = clone(estimator_orig)
@@ -1034,7 +1059,7 @@ def check_fit2d_1feature(name, estimator_orig):
10341059
# informative message
10351060
rnd = np.random.RandomState(0)
10361061
X = 3 * rnd.uniform(size=(10, 1))
1037-
X = pairwise_estimator_convert_X(X, estimator_orig)
1062+
X = _pairwise_estimator_convert_X(X, estimator_orig)
10381063
y = X[:, 0].astype(np.int)
10391064
estimator = clone(estimator_orig)
10401065
y = _enforce_estimator_tags_y(estimator, y)
@@ -1090,7 +1115,7 @@ def check_transformer_general(name, transformer, readonly_memmap=False):
10901115
random_state=0, n_features=2, cluster_std=0.1)
10911116
X = StandardScaler().fit_transform(X)
10921117
X -= X.min()
1093-
X = pairwise_estimator_convert_X(X, transformer)
1118+
X = _pairwise_estimator_convert_X(X, transformer)
10941119

10951120
if readonly_memmap:
10961121
X, y = create_memmap_backed_data([X, y])
@@ -1106,9 +1131,9 @@ def check_transformer_data_not_an_array(name, transformer):
11061131
# We need to make sure that we have non negative data, for things
11071132
# like NMF
11081133
X -= X.min() - .1
1109-
X = pairwise_estimator_convert_X(X, transformer)
1110-
this_X = NotAnArray(X)
1111-
this_y = NotAnArray(np.asarray(y))
1134+
X = _pairwise_estimator_convert_X(X, transformer)
1135+
this_X = _NotAnArray(X)
1136+
this_y = _NotAnArray(np.asarray(y))
11121137
_check_transformer(name, transformer, this_X, this_y)
11131138
# try the same with some list
11141139
_check_transformer(name, transformer, X.tolist(), y.tolist())
@@ -1212,7 +1237,7 @@ def check_pipeline_consistency(name, estimator_orig):
12121237
X, y = make_blobs(n_samples=30, centers=[[0, 0, 0], [1, 1, 1]],
12131238
random_state=0, n_features=2, cluster_std=0.1)
12141239
X -= X.min()
1215-
X = pairwise_estimator_convert_X(X, estimator_orig, kernel=rbf_kernel)
1240+
X = _pairwise_estimator_convert_X(X, estimator_orig, kernel=rbf_kernel)
12161241
estimator = clone(estimator_orig)
12171242
y = _enforce_estimator_tags_y(estimator, y)
12181243
set_random_state(estimator)
@@ -1238,7 +1263,7 @@ def check_fit_score_takes_y(name, estimator_orig):
12381263
rnd = np.random.RandomState(0)
12391264
n_samples = 30
12401265
X = rnd.uniform(size=(n_samples, 3))
1241-
X = pairwise_estimator_convert_X(X, estimator_orig)
1266+
X = _pairwise_estimator_convert_X(X, estimator_orig)
12421267
if _safe_tags(estimator_orig, 'binary_only'):
12431268
y = np.arange(n_samples) % 2
12441269
else:
@@ -1267,7 +1292,7 @@ def check_fit_score_takes_y(name, estimator_orig):
12671292
def check_estimators_dtypes(name, estimator_orig):
12681293
rnd = np.random.RandomState(0)
12691294
X_train_32 = 3 * rnd.uniform(size=(20, 5)).astype(np.float32)
1270-
X_train_32 = pairwise_estimator_convert_X(X_train_32, estimator_orig)
1295+
X_train_32 = _pairwise_estimator_convert_X(X_train_32, estimator_orig)
12711296
X_train_64 = X_train_32.astype(np.float64)
12721297
X_train_int_64 = X_train_32.astype(np.int64)
12731298
X_train_int_32 = X_train_32.astype(np.int32)
@@ -1315,7 +1340,7 @@ def check_estimators_empty_data_messages(name, estimator_orig):
13151340
def check_estimators_nan_inf(name, estimator_orig):
13161341
# Checks that Estimator X's do not contain NaN or inf.
13171342
rnd = np.random.RandomState(0)
1318-
X_train_finite = pairwise_estimator_convert_X(rnd.uniform(size=(10, 3)),
1343+
X_train_finite = _pairwise_estimator_convert_X(rnd.uniform(size=(10, 3)),
13191344
estimator_orig)
13201345
X_train_nan = rnd.uniform(size=(10, 3))
13211346
X_train_nan[0, 0] = np.nan
@@ -1406,7 +1431,7 @@ def check_estimators_pickle(name, estimator_orig):
14061431

14071432
# some estimators can't do features less than 0
14081433
X -= X.min()
1409-
X = pairwise_estimator_convert_X(X, estimator_orig, kernel=rbf_kernel)
1434+
X = _pairwise_estimator_convert_X(X, estimator_orig, kernel=rbf_kernel)
14101435

14111436
tags = _safe_tags(estimator_orig)
14121437
# include NaN values when the estimator should deal with them
@@ -1604,7 +1629,7 @@ def check_classifiers_train(name, classifier_orig, readonly_memmap=False):
16041629
n_classes = len(classes)
16051630
n_samples, n_features = X.shape
16061631
classifier = clone(classifier_orig)
1607-
X = pairwise_estimator_convert_X(X, classifier)
1632+
X = _pairwise_estimator_convert_X(X, classifier)
16081633
y = _enforce_estimator_tags_y(classifier, y)
16091634

16101635
set_random_state(classifier)
@@ -1807,7 +1832,7 @@ def check_estimators_fit_returns_self(name, estimator_orig,
18071832
X, y = make_blobs(random_state=0, n_samples=21, centers=n_centers)
18081833
# some want non-negative input
18091834
X -= X.min()
1810-
X = pairwise_estimator_convert_X(X, estimator_orig)
1835+
X = _pairwise_estimator_convert_X(X, estimator_orig)
18111836

18121837
estimator = clone(estimator_orig)
18131838
y = _enforce_estimator_tags_y(estimator, y)
@@ -1843,7 +1868,7 @@ def check_supervised_y_2d(name, estimator_orig):
18431868
return
18441869
rnd = np.random.RandomState(0)
18451870
n_samples = 30
1846-
X = pairwise_estimator_convert_X(
1871+
X = _pairwise_estimator_convert_X(
18471872
rnd.uniform(size=(n_samples, 3)), estimator_orig
18481873
)
18491874
if tags['binary_only']:
@@ -1943,8 +1968,8 @@ def check_classifiers_classes(name, classifier_orig):
19431968
X_binary = X_multiclass[y_multiclass != 2]
19441969
y_binary = y_multiclass[y_multiclass != 2]
19451970

1946-
X_multiclass = pairwise_estimator_convert_X(X_multiclass, classifier_orig)
1947-
X_binary = pairwise_estimator_convert_X(X_binary, classifier_orig)
1971+
X_multiclass = _pairwise_estimator_convert_X(X_multiclass, classifier_orig)
1972+
X_binary = _pairwise_estimator_convert_X(X_binary, classifier_orig)
19481973

19491974
labels_multiclass = ["one", "two", "three"]
19501975
labels_binary = ["one", "two"]
@@ -1970,7 +1995,7 @@ def check_classifiers_classes(name, classifier_orig):
19701995
@ignore_warnings(category=(DeprecationWarning, FutureWarning))
19711996
def check_regressors_int(name, regressor_orig):
19721997
X, _ = _boston_subset()
1973-
X = pairwise_estimator_convert_X(X[:50], regressor_orig)
1998+
X = _pairwise_estimator_convert_X(X[:50], regressor_orig)
19741999
rnd = np.random.RandomState(0)
19752000
y = rnd.randint(3, size=X.shape[0])
19762001
y = _enforce_estimator_tags_y(regressor_orig, y)
@@ -1998,7 +2023,7 @@ def check_regressors_int(name, regressor_orig):
19982023
@ignore_warnings(category=(DeprecationWarning, FutureWarning))
19992024
def check_regressors_train(name, regressor_orig, readonly_memmap=False):
20002025
X, y = _boston_subset()
2001-
X = pairwise_estimator_convert_X(X, regressor_orig)
2026+
X = _pairwise_estimator_convert_X(X, regressor_orig)
20022027
y = StandardScaler().fit_transform(y.reshape(-1, 1)) # X is already scaled
20032028
y = y.ravel()
20042029
regressor = clone(regressor_orig)
@@ -2047,7 +2072,7 @@ def check_regressors_no_decision_function(name, regressor_orig):
20472072
regressor = clone(regressor_orig)
20482073

20492074
X = rng.normal(size=(10, 4))
2050-
X = pairwise_estimator_convert_X(X, regressor_orig)
2075+
X = _pairwise_estimator_convert_X(X, regressor_orig)
20512076
y = _enforce_estimator_tags_y(regressor, X[:, 0])
20522077

20532078
if hasattr(regressor, "n_components"):
@@ -2186,7 +2211,7 @@ def check_estimators_overwrite_params(name, estimator_orig):
21862211
X, y = make_blobs(random_state=0, n_samples=21, centers=n_centers)
21872212
# some want non-negative input
21882213
X -= X.min()
2189-
X = pairwise_estimator_convert_X(X, estimator_orig, kernel=rbf_kernel)
2214+
X = _pairwise_estimator_convert_X(X, estimator_orig, kernel=rbf_kernel)
21902215
estimator = clone(estimator_orig)
21912216
y = _enforce_estimator_tags_y(estimator, y)
21922217

@@ -2277,7 +2302,7 @@ def check_sparsify_coefficients(name, estimator_orig):
22772302
def check_classifier_data_not_an_array(name, estimator_orig):
22782303
X = np.array([[3, 0], [0, 1], [0, 2], [1, 1], [1, 2], [2, 1],
22792304
[0, 3], [1, 0], [2, 0], [4, 4], [2, 3], [3, 2]])
2280-
X = pairwise_estimator_convert_X(X, estimator_orig)
2305+
X = _pairwise_estimator_convert_X(X, estimator_orig)
22812306
y = [1, 1, 1, 2, 2, 2, 1, 1, 1, 2, 2, 2]
22822307
y = _enforce_estimator_tags_y(estimator_orig, y)
22832308
check_estimators_data_not_an_array(name, estimator_orig, X, y)
@@ -2286,7 +2311,7 @@ def check_classifier_data_not_an_array(name, estimator_orig):
22862311
@ignore_warnings(category=DeprecationWarning)
22872312
def check_regressor_data_not_an_array(name, estimator_orig):
22882313
X, y = _boston_subset(n_samples=50)
2289-
X = pairwise_estimator_convert_X(X, estimator_orig)
2314+
X = _pairwise_estimator_convert_X(X, estimator_orig)
22902315
y = _enforce_estimator_tags_y(estimator_orig, y)
22912316
check_estimators_data_not_an_array(name, estimator_orig, X, y)
22922317

@@ -2303,8 +2328,8 @@ def check_estimators_data_not_an_array(name, estimator_orig, X, y):
23032328
set_random_state(estimator_1)
23042329
set_random_state(estimator_2)
23052330

2306-
y_ = NotAnArray(np.asarray(y))
2307-
X_ = NotAnArray(np.asarray(X))
2331+
y_ = _NotAnArray(np.asarray(y))
2332+
X_ = _NotAnArray(np.asarray(X))
23082333

23092334
# fit
23102335
estimator_1.fit(X_, y_)
@@ -2638,7 +2663,7 @@ def check_fit_idempotent(name, estimator_orig):
26382663

26392664
n_samples = 100
26402665
X = rng.normal(loc=100, size=(n_samples, 2))
2641-
X = pairwise_estimator_convert_X(X, estimator)
2666+
X = _pairwise_estimator_convert_X(X, estimator)
26422667
if is_regressor(estimator_orig):
26432668
y = rng.normal(size=n_samples)
26442669
else:

sklearn/utils/tests/test_deprecated_utils.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,11 @@
33

44
from sklearn.dummy import DummyClassifier
55
from sklearn.utils.estimator_checks import choose_check_classifiers_labels
6+
from sklearn.utils.estimator_checks import NotAnArray
67
from sklearn.utils.estimator_checks import enforce_estimator_tags_y
8+
from sklearn.utils.estimator_checks import is_public_parameter
9+
from sklearn.utils.estimator_checks import pairwise_estimator_convert_X
10+
from sklearn.utils.estimator_checks import set_checking_parameters
711

812

913
# This file tests the utils that are deprecated
@@ -17,3 +21,23 @@ def test_choose_check_classifiers_labels_deprecated():
1721
def test_enforce_estimator_tags_y():
1822
with pytest.warns(DeprecationWarning, match="removed in version 0.24"):
1923
enforce_estimator_tags_y(DummyClassifier(), np.array([0, 1]))
24+
25+
26+
def test_notanarray():
27+
with pytest.warns(DeprecationWarning, match="removed in version 0.24"):
28+
NotAnArray([1, 2])
29+
30+
31+
def test_is_public_parameter():
32+
with pytest.warns(DeprecationWarning, match="removed in version 0.24"):
33+
is_public_parameter('hello')
34+
35+
36+
def test_pairwise_estimator_convert_X():
37+
with pytest.warns(DeprecationWarning, match="removed in version 0.24"):
38+
pairwise_estimator_convert_X([[1, 2]], DummyClassifier())
39+
40+
41+
def test_set_checking_parameters():
42+
with pytest.warns(DeprecationWarning, match="removed in version 0.24"):
43+
set_checking_parameters(DummyClassifier())

0 commit comments

Comments
 (0)