Skip to content

Commit 1c11680

Browse files
sameshlrth
authored andcommitted
MAINT:Fix assert raises in sklearn/preprocessing/tests/ (scikit-learn#14717)
1 parent c52b6e1 commit 1c11680

File tree

5 files changed

+234
-185
lines changed

5 files changed

+234
-185
lines changed

sklearn/preprocessing/tests/test_base.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from scipy import sparse
44

55
from sklearn.utils.testing import assert_array_equal
6-
from sklearn.utils.testing import assert_raise_message
76
from sklearn.preprocessing.base import _transform_selected
87
from sklearn.preprocessing.data import Binarizer
98

@@ -61,22 +60,21 @@ def _mutating_transformer(X):
6160
def test_transform_selected_retain_order():
6261
X = [[-1, 1], [2, -2]]
6362

64-
assert_raise_message(ValueError,
65-
"The retain_order option can only be set to True "
66-
"for dense matrices.",
67-
_transform_selected, sparse.csr_matrix(X),
68-
Binarizer().transform, dtype=np.int, selected=[0],
69-
retain_order=True)
63+
err_msg = ("The retain_order option can only be set to True "
64+
"for dense matrices.")
65+
with pytest.raises(ValueError, match=err_msg):
66+
_transform_selected(sparse.csr_matrix(X), Binarizer().transform,
67+
dtype=np.int, selected=[0], retain_order=True)
7068

7169
def transform(X):
7270
return np.hstack((X, [[0], [0]]))
7371

74-
assert_raise_message(ValueError,
75-
"The retain_order option can only be set to True "
76-
"if the dimensions of the input array match the "
77-
"dimensions of the transformed array.",
78-
_transform_selected, X, transform, dtype=np.int,
79-
selected=[0], retain_order=True)
72+
err_msg = ("The retain_order option can only be set to True "
73+
"if the dimensions of the input array match the "
74+
"dimensions of the transformed array.")
75+
with pytest.raises(ValueError, match=err_msg):
76+
_transform_selected(X, transform, dtype=np.int,
77+
selected=[0], retain_order=True)
8078

8179
X_expected = [[-1, 1], [2, 0]]
8280
Xtr = _transform_selected(X, Binarizer().transform, dtype=np.int,

sklearn/preprocessing/tests/test_data.py

Lines changed: 102 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,10 @@
1616

1717
from sklearn.utils import gen_batches
1818

19-
from sklearn.utils.testing import assert_raise_message
2019
from sklearn.utils.testing import assert_almost_equal
2120
from sklearn.utils.testing import assert_array_almost_equal
2221
from sklearn.utils.testing import assert_array_equal
2322
from sklearn.utils.testing import assert_array_less
24-
from sklearn.utils.testing import assert_raises
25-
from sklearn.utils.testing import assert_raises_regex
2623
from sklearn.utils.testing import assert_warns_message
2724
from sklearn.utils.testing import assert_no_warnings
2825
from sklearn.utils.testing import assert_allclose
@@ -693,7 +690,8 @@ def test_min_max_scaler_iris():
693690

694691
# raises on invalid range
695692
scaler = MinMaxScaler(feature_range=(2, 1))
696-
assert_raises(ValueError, scaler.fit, X)
693+
with pytest.raises(ValueError):
694+
scaler.fit(X)
697695

698696

699697
def test_min_max_scaler_zero_variance_features():
@@ -791,8 +789,10 @@ def test_scaler_without_centering():
791789
X_csr = sparse.csr_matrix(X)
792790
X_csc = sparse.csc_matrix(X)
793791

794-
assert_raises(ValueError, StandardScaler().fit, X_csr)
795-
assert_raises(ValueError, StandardScaler().fit, X_csc)
792+
with pytest.raises(ValueError):
793+
StandardScaler().fit(X_csr)
794+
with pytest.raises(ValueError):
795+
StandardScaler().fit(X_csc)
796796

797797
null_transform = StandardScaler(with_mean=False, with_std=False, copy=True)
798798
X_null = null_transform.fit_transform(X_csr)
@@ -1024,30 +1024,38 @@ def test_scale_sparse_with_mean_raise_exception():
10241024
X_csc = sparse.csc_matrix(X)
10251025

10261026
# check scaling and fit with direct calls on sparse data
1027-
assert_raises(ValueError, scale, X_csr, with_mean=True)
1028-
assert_raises(ValueError, StandardScaler(with_mean=True).fit, X_csr)
1027+
with pytest.raises(ValueError):
1028+
scale(X_csr, with_mean=True)
1029+
with pytest.raises(ValueError):
1030+
StandardScaler(with_mean=True).fit(X_csr)
10291031

1030-
assert_raises(ValueError, scale, X_csc, with_mean=True)
1031-
assert_raises(ValueError, StandardScaler(with_mean=True).fit, X_csc)
1032+
with pytest.raises(ValueError):
1033+
scale(X_csc, with_mean=True)
1034+
with pytest.raises(ValueError):
1035+
StandardScaler(with_mean=True).fit(X_csc)
10321036

10331037
# check transform and inverse_transform after a fit on a dense array
10341038
scaler = StandardScaler(with_mean=True).fit(X)
1035-
assert_raises(ValueError, scaler.transform, X_csr)
1036-
assert_raises(ValueError, scaler.transform, X_csc)
1039+
with pytest.raises(ValueError):
1040+
scaler.transform(X_csr)
1041+
with pytest.raises(ValueError):
1042+
scaler.transform(X_csc)
10371043

10381044
X_transformed_csr = sparse.csr_matrix(scaler.transform(X))
1039-
assert_raises(ValueError, scaler.inverse_transform, X_transformed_csr)
1045+
with pytest.raises(ValueError):
1046+
scaler.inverse_transform(X_transformed_csr)
10401047

10411048
X_transformed_csc = sparse.csc_matrix(scaler.transform(X))
1042-
assert_raises(ValueError, scaler.inverse_transform, X_transformed_csc)
1049+
with pytest.raises(ValueError):
1050+
scaler.inverse_transform(X_transformed_csc)
10431051

10441052

10451053
def test_scale_input_finiteness_validation():
10461054
# Check if non finite inputs raise ValueError
10471055
X = [[np.inf, 5, 6, 7, 8]]
1048-
assert_raises_regex(ValueError,
1049-
"Input contains infinity or a value too large",
1050-
scale, X)
1056+
with pytest.raises(ValueError, match="Input contains infinity "
1057+
"or a value too large"):
1058+
scale(X)
10511059

10521060

10531061
def test_robust_scaler_error_sparse():
@@ -1201,57 +1209,63 @@ def test_quantile_transform_check_error():
12011209
[0, 0, 2.6, 4.1, 0, 0, 2.3, 0, 9.5, 0.1]])
12021210
X_neg = sparse.csc_matrix(X_neg)
12031211

1204-
assert_raises_regex(ValueError, "Invalid value for 'n_quantiles': 0.",
1205-
QuantileTransformer(n_quantiles=0).fit, X)
1206-
assert_raises_regex(ValueError, "Invalid value for 'subsample': 0.",
1207-
QuantileTransformer(subsample=0).fit, X)
1208-
assert_raises_regex(ValueError, "The number of quantiles cannot be"
1209-
" greater than the number of samples used. Got"
1210-
" 1000 quantiles and 10 samples.",
1211-
QuantileTransformer(subsample=10).fit, X)
1212+
err_msg = "Invalid value for 'n_quantiles': 0."
1213+
with pytest.raises(ValueError, match=err_msg):
1214+
QuantileTransformer(n_quantiles=0).fit(X)
1215+
err_msg = "Invalid value for 'subsample': 0."
1216+
with pytest.raises(ValueError, match=err_msg):
1217+
QuantileTransformer(subsample=0).fit(X)
1218+
err_msg = ("The number of quantiles cannot be greater than "
1219+
"the number of samples used. Got 1000 quantiles "
1220+
"and 10 samples.")
1221+
with pytest.raises(ValueError, match=err_msg):
1222+
QuantileTransformer(subsample=10).fit(X)
12121223

12131224
transformer = QuantileTransformer(n_quantiles=10)
1214-
assert_raises_regex(ValueError, "QuantileTransformer only accepts "
1215-
"non-negative sparse matrices.",
1216-
transformer.fit, X_neg)
1225+
err_msg = "QuantileTransformer only accepts non-negative sparse matrices."
1226+
with pytest.raises(ValueError, match=err_msg):
1227+
transformer.fit(X_neg)
12171228
transformer.fit(X)
1218-
assert_raises_regex(ValueError, "QuantileTransformer only accepts "
1219-
"non-negative sparse matrices.",
1220-
transformer.transform, X_neg)
1229+
err_msg = "QuantileTransformer only accepts non-negative sparse matrices."
1230+
with pytest.raises(ValueError, match=err_msg):
1231+
transformer.transform(X_neg)
12211232

12221233
X_bad_feat = np.transpose([[0, 25, 50, 0, 0, 0, 75, 0, 0, 100],
12231234
[0, 0, 2.6, 4.1, 0, 0, 2.3, 0, 9.5, 0.1]])
1224-
assert_raises_regex(ValueError, "X does not have the same number of "
1225-
"features as the previously fitted data. Got 2"
1226-
" instead of 3.",
1227-
transformer.transform, X_bad_feat)
1228-
assert_raises_regex(ValueError, "X does not have the same number of "
1229-
"features as the previously fitted data. Got 2"
1230-
" instead of 3.",
1231-
transformer.inverse_transform, X_bad_feat)
1235+
err_msg = ("X does not have the same number of features as the previously"
1236+
" fitted " "data. Got 2 instead of 3.")
1237+
with pytest.raises(ValueError, match=err_msg):
1238+
transformer.transform(X_bad_feat)
1239+
err_msg = ("X does not have the same number of features "
1240+
"as the previously fitted data. Got 2 instead of 3.")
1241+
with pytest.raises(ValueError, match=err_msg):
1242+
transformer.inverse_transform(X_bad_feat)
12321243

12331244
transformer = QuantileTransformer(n_quantiles=10,
12341245
output_distribution='rnd')
12351246
# check that an error is raised at fit time
1236-
assert_raises_regex(ValueError, "'output_distribution' has to be either"
1237-
" 'normal' or 'uniform'. Got 'rnd' instead.",
1238-
transformer.fit, X)
1247+
err_msg = ("'output_distribution' has to be either 'normal' or "
1248+
"'uniform'. Got 'rnd' instead.")
1249+
with pytest.raises(ValueError, match=err_msg):
1250+
transformer.fit(X)
12391251
# check that an error is raised at transform time
12401252
transformer.output_distribution = 'uniform'
12411253
transformer.fit(X)
12421254
X_tran = transformer.transform(X)
12431255
transformer.output_distribution = 'rnd'
1244-
assert_raises_regex(ValueError, "'output_distribution' has to be either"
1245-
" 'normal' or 'uniform'. Got 'rnd' instead.",
1246-
transformer.transform, X)
1256+
err_msg = ("'output_distribution' has to be either 'normal' or 'uniform'."
1257+
" Got 'rnd' instead.")
1258+
with pytest.raises(ValueError, match=err_msg):
1259+
transformer.transform(X)
12471260
# check that an error is raised at inverse_transform time
1248-
assert_raises_regex(ValueError, "'output_distribution' has to be either"
1249-
" 'normal' or 'uniform'. Got 'rnd' instead.",
1250-
transformer.inverse_transform, X_tran)
1261+
err_msg = ("'output_distribution' has to be either 'normal' or 'uniform'."
1262+
" Got 'rnd' instead.")
1263+
with pytest.raises(ValueError, match=err_msg):
1264+
transformer.inverse_transform(X_tran)
12511265
# check that an error is raised if input is scalar
1252-
assert_raise_message(ValueError,
1253-
'Expected 2D array, got scalar array instead',
1254-
transformer.transform, 10)
1266+
with pytest.raises(ValueError,
1267+
match='Expected 2D array, got scalar array instead'):
1268+
transformer.transform(10)
12551269
# check that a warning is raised is n_quantiles > n_samples
12561270
transformer = QuantileTransformer(n_quantiles=100)
12571271
warn_msg = "n_quantiles is set to n_samples"
@@ -1541,8 +1555,8 @@ def test_robust_scaler_invalid_range():
15411555
]:
15421556
scaler = RobustScaler(quantile_range=range_)
15431557

1544-
assert_raises_regex(ValueError, r'Invalid quantile range: \(',
1545-
scaler.fit, iris.data)
1558+
with pytest.raises(ValueError, match=r'Invalid quantile range: \('):
1559+
scaler.fit(iris.data)
15461560

15471561

15481562
def test_scale_function_without_centering():
@@ -1562,7 +1576,8 @@ def test_scale_function_without_centering():
15621576
assert_array_almost_equal(X_scaled, X_csc_scaled.toarray())
15631577

15641578
# raises value error on axis != 0
1565-
assert_raises(ValueError, scale, X_csr, with_mean=False, axis=1)
1579+
with pytest.raises(ValueError):
1580+
scale(X_csr, with_mean=False, axis=1)
15661581

15671582
assert_array_almost_equal(X_scaled.mean(axis=0),
15681583
[0., -0.01, 2.24, -0.35, -0.78], 2)
@@ -1951,8 +1966,10 @@ def test_normalize():
19511966
X = np.random.RandomState(37).randn(3, 2)
19521967
assert_array_equal(normalize(X, copy=False),
19531968
normalize(X.T, axis=0, copy=False).T)
1954-
assert_raises(ValueError, normalize, [[0]], axis=2)
1955-
assert_raises(ValueError, normalize, [[0]], norm='l3')
1969+
with pytest.raises(ValueError):
1970+
normalize([[0]], axis=2)
1971+
with pytest.raises(ValueError):
1972+
normalize([[0]], norm='l3')
19561973

19571974
rs = np.random.RandomState(0)
19581975
X_dense = rs.randn(10, 5)
@@ -1987,8 +2004,8 @@ def test_normalize():
19872004

19882005
X_sparse = sparse.csr_matrix(X_dense)
19892006
for norm in ('l1', 'l2'):
1990-
assert_raises(NotImplementedError, normalize, X_sparse,
1991-
norm=norm, return_norm=True)
2007+
with pytest.raises(NotImplementedError):
2008+
normalize(X_sparse, norm=norm, return_norm=True)
19922009
_, norms = normalize(X_sparse, norm='max', return_norm=True)
19932010
assert_array_almost_equal(norms, np.array([4.0, 1.0, 3.0]))
19942011

@@ -2045,7 +2062,8 @@ def test_binarizer():
20452062
X_bin = binarizer.transform(X)
20462063

20472064
# Cannot use threshold < 0 for sparse
2048-
assert_raises(ValueError, binarizer.transform, sparse.csc_matrix(X))
2065+
with pytest.raises(ValueError):
2066+
binarizer.transform(sparse.csc_matrix(X))
20492067

20502068

20512069
def test_center_kernel():
@@ -2151,16 +2169,19 @@ def test_quantile_transform_valid_axis():
21512169
[2, 4, 6, 8, 10],
21522170
[2.6, 4.1, 2.3, 9.5, 0.1]])
21532171

2154-
assert_raises_regex(ValueError, "axis should be either equal to 0 or 1"
2155-
". Got axis=2", quantile_transform, X.T, axis=2)
2172+
with pytest.raises(ValueError, match="axis should be either equal "
2173+
"to 0 or 1. Got axis=2"):
2174+
quantile_transform(X.T, axis=2)
21562175

21572176

21582177
@pytest.mark.parametrize("method", ['box-cox', 'yeo-johnson'])
21592178
def test_power_transformer_notfitted(method):
21602179
pt = PowerTransformer(method=method)
21612180
X = np.abs(X_1col)
2162-
assert_raises(NotFittedError, pt.transform, X)
2163-
assert_raises(NotFittedError, pt.inverse_transform, X)
2181+
with pytest.raises(NotFittedError):
2182+
pt.transform(X)
2183+
with pytest.raises(NotFittedError):
2184+
pt.inverse_transform(X)
21642185

21652186

21662187
@pytest.mark.parametrize('method', ['box-cox', 'yeo-johnson'])
@@ -2241,23 +2262,23 @@ def test_power_transformer_boxcox_strictly_positive_exception():
22412262
X_with_negatives = X_2d
22422263
not_positive_message = 'strictly positive'
22432264

2244-
assert_raise_message(ValueError, not_positive_message,
2245-
pt.transform, X_with_negatives)
2265+
with pytest.raises(ValueError, match=not_positive_message):
2266+
pt.transform(X_with_negatives)
22462267

2247-
assert_raise_message(ValueError, not_positive_message,
2248-
pt.fit, X_with_negatives)
2268+
with pytest.raises(ValueError, match=not_positive_message):
2269+
pt.fit(X_with_negatives)
22492270

2250-
assert_raise_message(ValueError, not_positive_message,
2251-
power_transform, X_with_negatives, 'box-cox')
2271+
with pytest.raises(ValueError, match=not_positive_message):
2272+
power_transform(X_with_negatives, 'box-cox')
22522273

2253-
assert_raise_message(ValueError, not_positive_message,
2254-
pt.transform, np.zeros(X_2d.shape))
2274+
with pytest.raises(ValueError, match=not_positive_message):
2275+
pt.transform(np.zeros(X_2d.shape))
22552276

2256-
assert_raise_message(ValueError, not_positive_message,
2257-
pt.fit, np.zeros(X_2d.shape))
2277+
with pytest.raises(ValueError, match=not_positive_message):
2278+
pt.fit(np.zeros(X_2d.shape))
22582279

2259-
assert_raise_message(ValueError, not_positive_message,
2260-
power_transform, np.zeros(X_2d.shape), 'box-cox')
2280+
with pytest.raises(ValueError, match=not_positive_message):
2281+
power_transform(np.zeros(X_2d.shape), 'box-cox')
22612282

22622283

22632284
@pytest.mark.parametrize('X', [X_2d, np.abs(X_2d), -np.abs(X_2d),
@@ -2277,11 +2298,11 @@ def test_power_transformer_shape_exception(method):
22772298
# than during fitting
22782299
wrong_shape_message = 'Input data has a different number of features'
22792300

2280-
assert_raise_message(ValueError, wrong_shape_message,
2281-
pt.transform, X[:, 0:1])
2301+
with pytest.raises(ValueError, match=wrong_shape_message):
2302+
pt.transform(X[:, 0:1])
22822303

2283-
assert_raise_message(ValueError, wrong_shape_message,
2284-
pt.inverse_transform, X[:, 0:1])
2304+
with pytest.raises(ValueError, match=wrong_shape_message):
2305+
pt.inverse_transform(X[:, 0:1])
22852306

22862307

22872308
def test_power_transformer_method_exception():
@@ -2290,8 +2311,8 @@ def test_power_transformer_method_exception():
22902311

22912312
# An exception should be raised if PowerTransformer.method isn't valid
22922313
bad_method_message = "'method' must be one of"
2293-
assert_raise_message(ValueError, bad_method_message,
2294-
pt.fit, X)
2314+
with pytest.raises(ValueError, match=bad_method_message):
2315+
pt.fit(X)
22952316

22962317

22972318
def test_power_transformer_lambda_zero():

0 commit comments

Comments
 (0)