Skip to content

Commit 250c6a3

Browse files
ahojnnesogrisel
authored andcommitted
Use assert_equal, assert_less rather than plain assert statement
1 parent 9aeb896 commit 250c6a3

File tree

1 file changed

+10
-7
lines changed

1 file changed

+10
-7
lines changed

sklearn/linear_model/tests/test_ransac.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from numpy.testing import assert_equal, assert_raises
33
from scipy import sparse
44

5+
from sklearn.utils.testing import assert_less
56
from sklearn.linear_model import LinearRegression, RANSACRegressor
67

78

@@ -39,7 +40,8 @@ def test_ransac_inliers_outliers():
3940

4041
def test_ransac_is_data_valid():
4142
def is_data_valid(X, y):
42-
assert X.shape[0] == y.shape[0] == 2
43+
assert_equal(X.shape[0], 2)
44+
assert_equal(y.shape[0], 2)
4345
return False
4446

4547
X = np.random.rand(10, 2)
@@ -56,7 +58,8 @@ def is_data_valid(X, y):
5658

5759
def test_ransac_is_model_valid():
5860
def is_model_valid(estimator, X, y):
59-
assert X.shape[0] == y.shape[0] == 2
61+
assert_equal(X.shape[0], 2)
62+
assert_equal(y.shape[0], 2)
6063
return False
6164

6265
base_estimator = LinearRegression()
@@ -81,7 +84,7 @@ def test_ransac_max_trials():
8184
random_state=0)
8285
assert getattr(ransac_estimator, 'n_trials_', None) is None
8386
ransac_estimator.fit(X, y)
84-
assert ransac_estimator.n_trials_ == 11
87+
assert_equal(ransac_estimator.n_trials_, 11)
8588

8689

8790
def test_ransac_stop_n_inliers():
@@ -91,7 +94,7 @@ def test_ransac_stop_n_inliers():
9194
random_state=0)
9295
ransac_estimator.fit(X, y)
9396

94-
assert ransac_estimator.n_trials_ == 1
97+
assert_equal(ransac_estimator.n_trials_, 1)
9598

9699

97100
def test_ransac_stop_score():
@@ -101,7 +104,7 @@ def test_ransac_stop_score():
101104
random_state=0)
102105
ransac_estimator.fit(X, y)
103106

104-
assert ransac_estimator.n_trials_ == 1
107+
assert_equal(ransac_estimator.n_trials_, 1)
105108

106109

107110
def test_ransac_score():
@@ -115,8 +118,8 @@ def test_ransac_score():
115118
residual_threshold=0.5, random_state=0)
116119
ransac_estimator.fit(X, y)
117120

118-
assert ransac_estimator.score(X[2:], y[2:]) == 1
119-
assert ransac_estimator.score(X[:2], y[:2]) < 1
121+
assert_equal(ransac_estimator.score(X[2:], y[2:]), 1)
122+
assert_less(ransac_estimator.score(X[:2], y[:2]), 1)
120123

121124

122125
def test_ransac_predict():

0 commit comments

Comments
 (0)