Skip to content

Commit 4bf824c

Browse files
ogrisel and amueller
authored and committed
TST non-regression test for CV on text pipelines
1 parent e242ebc commit 4bf824c

File tree

1 file changed

+35
-22
lines changed

1 file changed

+35
-22
lines changed

sklearn/feature_extraction/tests/test_text.py

Lines changed: 35 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212

1313
from sklearn.feature_extraction.text import ENGLISH_STOP_WORDS
1414

15+
from sklearn.cross_validation import train_test_split
16+
from sklearn.cross_validation import cross_val_score
1517
from sklearn.grid_search import GridSearchCV
1618
from sklearn.pipeline import Pipeline
1719
from sklearn.svm import LinearSVC
@@ -704,15 +706,13 @@ def test_vectorizer_inverse_transform():
704706
def test_count_vectorizer_pipeline_grid_selection():
705707
# raw documents
706708
data = JUNK_FOOD_DOCS + NOTJUNK_FOOD_DOCS
707-
# simulate iterables
708-
train_data = iter(data[1:-1])
709-
test_data = iter([data[0], data[-1]])
710709

711710
# label junk food as -1, the others as +1
712-
y = np.ones(len(data))
713-
y[:6] = -1
714-
y_train = y[1:-1]
715-
y_test = np.array([y[0], y[-1]])
711+
target = [-1] * len(JUNK_FOOD_DOCS) + [1] * len(NOTJUNK_FOOD_DOCS)
712+
713+
# split the dataset for model development and final evaluation
714+
train_data, test_data, target_train, target_test = train_test_split(
715+
data, target, test_size=.2, random_state=0)
716716

717717
pipeline = Pipeline([('vect', CountVectorizer()),
718718
('svc', LinearSVC())])
@@ -726,10 +726,10 @@ def test_count_vectorizer_pipeline_grid_selection():
726726
# classifier
727727
grid_search = GridSearchCV(pipeline, parameters, n_jobs=1)
728728

729-
# cross-validation doesn't work if the length of the data is not known,
730-
# hence use lists instead of iterators
731-
pred = grid_search.fit(list(train_data), y_train).predict(list(test_data))
732-
assert_array_equal(pred, y_test)
729+
# Check that the best model found by grid search is 100% correct on the
730+
# held out evaluation set.
731+
pred = grid_search.fit(train_data, target_train).predict(test_data)
732+
assert_array_equal(pred, target_test)
733733

734734
# on this toy dataset bigram representation which is used in the last of
735735
# the grid_search is considered the best estimator since they all converge
@@ -742,15 +742,13 @@ def test_count_vectorizer_pipeline_grid_selection():
742742
def test_vectorizer_pipeline_grid_selection():
743743
# raw documents
744744
data = JUNK_FOOD_DOCS + NOTJUNK_FOOD_DOCS
745-
# simulate iterables
746-
train_data = iter(data[1:-1])
747-
test_data = iter([data[0], data[-1]])
748745

749746
# label junk food as -1, the others as +1
750-
y = np.ones(len(data))
751-
y[:6] = -1
752-
y_train = y[1:-1]
753-
y_test = np.array([y[0], y[-1]])
747+
target = [-1] * len(JUNK_FOOD_DOCS) + [1] * len(NOTJUNK_FOOD_DOCS)
748+
749+
# split the dataset for model development and final evaluation
750+
train_data, test_data, target_train, target_test = train_test_split(
751+
data, target, test_size=.1, random_state=0)
754752

755753
pipeline = Pipeline([('vect', TfidfVectorizer()),
756754
('svc', LinearSVC())])
@@ -765,10 +763,10 @@ def test_vectorizer_pipeline_grid_selection():
765763
# classifier
766764
grid_search = GridSearchCV(pipeline, parameters, n_jobs=1)
767765

768-
# cross-validation doesn't work if the length of the data is not known,
769-
# hence use lists instead of iterators
770-
pred = grid_search.fit(list(train_data), y_train).predict(list(test_data))
771-
assert_array_equal(pred, y_test)
766+
# Check that the best model found by grid search is 100% correct on the
767+
# held out evaluation set.
768+
pred = grid_search.fit(train_data, target_train).predict(test_data)
769+
assert_array_equal(pred, target_test)
772770

773771
# on this toy dataset bigram representation which is used in the last of
774772
# the grid_search is considered the best estimator since they all converge
@@ -780,6 +778,21 @@ def test_vectorizer_pipeline_grid_selection():
780778
assert_false(best_vectorizer.fixed_vocabulary)
781779

782780

781+
def test_vectorizer_pipeline_cross_validation():
782+
# raw documents
783+
data = JUNK_FOOD_DOCS + NOTJUNK_FOOD_DOCS
784+
785+
# label junk food as -1, the others as +1
786+
target = [-1] * len(JUNK_FOOD_DOCS) + [1] * len(NOTJUNK_FOOD_DOCS)
787+
788+
789+
pipeline = Pipeline([('vect', TfidfVectorizer()),
790+
('svc', LinearSVC())])
791+
792+
cv_scores = cross_val_score(pipeline, data, target, cv=3)
793+
assert_array_equal(cv_scores, [1., 1., 1.])
794+
795+
783796
def test_vectorizer_unicode():
784797
# tests that the count vectorizer works with cyrillic.
785798
document = (

0 commit comments

Comments (0)