22from numpy .testing import assert_equal , assert_raises
33from scipy import sparse
44
5+ from sklearn .utils .testing import assert_less
56from sklearn .linear_model import LinearRegression , RANSACRegressor
67
78
@@ -39,7 +40,8 @@ def test_ransac_inliers_outliers():
3940
4041def test_ransac_is_data_valid ():
4142 def is_data_valid (X , y ):
42- assert X .shape [0 ] == y .shape [0 ] == 2
43+ assert_equal (X .shape [0 ], 2 )
44+ assert_equal (y .shape [0 ], 2 )
4345 return False
4446
4547 X = np .random .rand (10 , 2 )
@@ -56,7 +58,8 @@ def is_data_valid(X, y):
5658
5759def test_ransac_is_model_valid ():
5860 def is_model_valid (estimator , X , y ):
59- assert X .shape [0 ] == y .shape [0 ] == 2
61+ assert_equal (X .shape [0 ], 2 )
62+ assert_equal (y .shape [0 ], 2 )
6063 return False
6164
6265 base_estimator = LinearRegression ()
@@ -81,7 +84,7 @@ def test_ransac_max_trials():
8184 random_state = 0 )
8285 assert getattr (ransac_estimator , 'n_trials_' , None ) is None
8386 ransac_estimator .fit (X , y )
84- assert ransac_estimator .n_trials_ == 11
87+ assert_equal ( ransac_estimator .n_trials_ , 11 )
8588
8689
8790def test_ransac_stop_n_inliers ():
@@ -91,7 +94,7 @@ def test_ransac_stop_n_inliers():
9194 random_state = 0 )
9295 ransac_estimator .fit (X , y )
9396
94- assert ransac_estimator .n_trials_ == 1
97+ assert_equal ( ransac_estimator .n_trials_ , 1 )
9598
9699
97100def test_ransac_stop_score ():
@@ -101,7 +104,7 @@ def test_ransac_stop_score():
101104 random_state = 0 )
102105 ransac_estimator .fit (X , y )
103106
104- assert ransac_estimator .n_trials_ == 1
107+ assert_equal ( ransac_estimator .n_trials_ , 1 )
105108
106109
107110def test_ransac_score ():
@@ -115,8 +118,8 @@ def test_ransac_score():
115118 residual_threshold = 0.5 , random_state = 0 )
116119 ransac_estimator .fit (X , y )
117120
118- assert ransac_estimator .score (X [2 :], y [2 :]) == 1
119- assert ransac_estimator .score (X [:2 ], y [:2 ]) < 1
121+ assert_equal ( ransac_estimator .score (X [2 :], y [2 :]), 1 )
122+ assert_less ( ransac_estimator .score (X [:2 ], y [:2 ]), 1 )
120123
121124
122125def test_ransac_predict ():
0 commit comments