4242class MockListClassifier (BaseEstimator ):
4343 """Dummy classifier to test the cross-validation.
4444
45- Checks that GridSearchCV didn't convert X to array.
45+ Checks that GridSearchCV didn't convert X or Y to array.
4646 """
47- def __init__ (self , foo_param = 0 ):
47+ def __init__ (self , foo_param = 0 , check_y = False , check_X = True ):
4848 self .foo_param = foo_param
49+ self .check_y = check_y
50+ self .check_X = check_X
4951
5052 def fit (self , X , Y ):
5153 assert_true (len (X ) == len (Y ))
52- assert_true (isinstance (X , list ))
54+ if self .check_X :
55+ assert_true (isinstance (X , list ))
56+ if self .check_y :
57+ assert_true (isinstance (Y , list ))
58+
5359 return self
5460
5561 def predict (self , T ):
@@ -513,6 +519,9 @@ def test_cross_val_score():
513519 clf = MockListClassifier ()
514520 scores = cval .cross_val_score (clf , X .tolist (), y .tolist ())
515521
522+ clf = MockListClassifier (check_X = False , check_y = True )
523+ scores = cval .cross_val_score (clf , X , y .tolist ())
524+
516525 assert_raises (ValueError , cval .cross_val_score , clf , X , y ,
517526 scoring = "sklearn" )
518527
@@ -596,7 +605,7 @@ def test_train_test_split():
596605 X = np .arange (100 ).reshape ((10 , 10 ))
597606 X_s = coo_matrix (X )
598607 y = range (10 )
599- split = cval .train_test_split (X , X_s , y )
608+ split = cval .train_test_split (X , X_s , y , allow_lists = False )
600609 X_train , X_test , X_s_train , X_s_test , y_train , y_test = split
601610 assert_array_equal (X_train , X_s_train .toarray ())
602611 assert_array_equal (X_test , X_s_test .toarray ())
@@ -606,6 +615,11 @@ def test_train_test_split():
606615 X_train , X_test , y_train , y_test = split
607616 assert_equal (len (y_test ), len (y_train ))
608617
618+ split = cval .train_test_split (X , X_s , y )
619+ X_train , X_test , X_s_train , X_s_test , y_train , y_test = split
620+ assert_true (isinstance (y_train , list ))
621+ assert_true (isinstance (y_test , list ))
622+
609623
610624def test_cross_val_score_with_score_func_classification ():
611625 iris = load_iris ()
@@ -911,7 +925,7 @@ def test_train_test_split_allow_nans():
911925 X = np .arange (200 , dtype = np .float64 ).reshape (10 , - 1 )
912926 X [2 , :] = np .nan
913927 y = np .repeat ([0 , 1 ], X .shape [0 ]/ 2 )
914- split = cval .train_test_split (X , y , test_size = 0.2 , random_state = 42 )
928+ cval .train_test_split (X , y , test_size = 0.2 , random_state = 42 )
915929
916930
917931def test_permutation_test_score_allow_nans ():
0 commit comments