1212
1313from sklearn .feature_extraction .text import ENGLISH_STOP_WORDS
1414
15+ from sklearn .cross_validation import train_test_split
16+ from sklearn .cross_validation import cross_val_score
1517from sklearn .grid_search import GridSearchCV
1618from sklearn .pipeline import Pipeline
1719from sklearn .svm import LinearSVC
@@ -704,15 +706,13 @@ def test_vectorizer_inverse_transform():
704706def test_count_vectorizer_pipeline_grid_selection ():
705707 # raw documents
706708 data = JUNK_FOOD_DOCS + NOTJUNK_FOOD_DOCS
707- # simulate iterables
708- train_data = iter (data [1 :- 1 ])
709- test_data = iter ([data [0 ], data [- 1 ]])
710709
711710 # label junk food as -1, the others as +1
712- y = np .ones (len (data ))
713- y [:6 ] = - 1
714- y_train = y [1 :- 1 ]
715- y_test = np .array ([y [0 ], y [- 1 ]])
711+ target = [- 1 ] * len (JUNK_FOOD_DOCS ) + [1 ] * len (NOTJUNK_FOOD_DOCS )
712+
713+ # split the dataset for model development and final evaluation
714+ train_data , test_data , target_train , target_test = train_test_split (
715+ data , target , test_size = .2 , random_state = 0 )
716716
717717 pipeline = Pipeline ([('vect' , CountVectorizer ()),
718718 ('svc' , LinearSVC ())])
@@ -726,10 +726,10 @@ def test_count_vectorizer_pipeline_grid_selection():
726726 # classifier
727727 grid_search = GridSearchCV (pipeline , parameters , n_jobs = 1 )
728728
729- # cross-validation doesn't work if the length of the data is not known,
730- # hence use lists instead of iterators
731- pred = grid_search .fit (list ( train_data ), y_train ).predict (list ( test_data ) )
732- assert_array_equal (pred , y_test )
729+ # Check that the best model found by grid search is 100% correct on the
730+ # held out evaluation set.
731+ pred = grid_search .fit (train_data , target_train ).predict (test_data )
732+ assert_array_equal (pred , target_test )
733733
734734 # on this toy dataset bigram representation which is used in the last of
735735 # the grid_search is considered the best estimator since they all converge
@@ -742,15 +742,13 @@ def test_count_vectorizer_pipeline_grid_selection():
742742def test_vectorizer_pipeline_grid_selection ():
743743 # raw documents
744744 data = JUNK_FOOD_DOCS + NOTJUNK_FOOD_DOCS
745- # simulate iterables
746- train_data = iter (data [1 :- 1 ])
747- test_data = iter ([data [0 ], data [- 1 ]])
748745
749746 # label junk food as -1, the others as +1
750- y = np .ones (len (data ))
751- y [:6 ] = - 1
752- y_train = y [1 :- 1 ]
753- y_test = np .array ([y [0 ], y [- 1 ]])
747+ target = [- 1 ] * len (JUNK_FOOD_DOCS ) + [1 ] * len (NOTJUNK_FOOD_DOCS )
748+
749+ # split the dataset for model development and final evaluation
750+ train_data , test_data , target_train , target_test = train_test_split (
751+ data , target , test_size = .1 , random_state = 0 )
754752
755753 pipeline = Pipeline ([('vect' , TfidfVectorizer ()),
756754 ('svc' , LinearSVC ())])
@@ -765,10 +763,10 @@ def test_vectorizer_pipeline_grid_selection():
765763 # classifier
766764 grid_search = GridSearchCV (pipeline , parameters , n_jobs = 1 )
767765
768- # cross-validation doesn't work if the length of the data is not known,
769- # hence use lists instead of iterators
770- pred = grid_search .fit (list ( train_data ), y_train ).predict (list ( test_data ) )
771- assert_array_equal (pred , y_test )
766+ # Check that the best model found by grid search is 100% correct on the
767+ # held out evaluation set.
768+ pred = grid_search .fit (train_data , target_train ).predict (test_data )
769+ assert_array_equal (pred , target_test )
772770
773771 # on this toy dataset bigram representation which is used in the last of
774772 # the grid_search is considered the best estimator since they all converge
@@ -780,6 +778,21 @@ def test_vectorizer_pipeline_grid_selection():
780778 assert_false (best_vectorizer .fixed_vocabulary )
781779
782780
def test_vectorizer_pipeline_cross_validation():
    # raw documents
    data = JUNK_FOOD_DOCS + NOTJUNK_FOOD_DOCS

    # label junk food as -1, the others as +1
    target = [-1] * len(JUNK_FOOD_DOCS) + [1] * len(NOTJUNK_FOOD_DOCS)

    pipeline = Pipeline([('vect', TfidfVectorizer()),
                         ('svc', LinearSVC())])

    # Each of the 3 folds of this toy dataset should be perfectly
    # separable by the tf-idf + linear SVC pipeline.
    cv_scores = cross_val_score(pipeline, data, target, cv=3)
    assert_array_equal(cv_scores, [1., 1., 1.])
795+
783796def test_vectorizer_unicode ():
784797 # tests that the count vectorizer works with cyrillic.
785798 document = (
0 commit comments