Skip to content

Commit cc9dbac

Browse files
committed
Adding test for PR scikit-learn#6900
1 parent 2242b1b commit cc9dbac

File tree

1 file changed

+22
-0
lines changed

1 file changed

+22
-0
lines changed

sklearn/preprocessing/tests/test_data.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,11 @@
5252
from sklearn.preprocessing.data import PolynomialFeatures
5353
from sklearn.exceptions import DataConversionWarning
5454

55+
from sklearn.pipeline import Pipeline
56+
from sklearn.cross_validation import cross_val_score
57+
from sklearn.cross_validation import LeaveOneOut
58+
from sklearn.svm import SVR
59+
5560
from sklearn import datasets
5661

5762
iris = datasets.load_iris()
@@ -1369,6 +1374,23 @@ def test_center_kernel():
13691374
K_pred_centered2 = centerer.transform(K_pred)
13701375
assert_array_almost_equal(K_pred_centered, K_pred_centered2)
13711376

1377+
def test_cv_pipeline_precomputed():
1378+
"""Cross-validate a regression on four coplanar points with the same
1379+
value. Use precomputed kernel to ensure Pipeline with KernelCenterer
1380+
is treated as a _pairwise operation."""
1381+
X = np.array([[3,0,0],[0,3,0],[0,0,3],[1,1,1]])
1382+
y = np.ones((4,))
1383+
K = X.dot(X.T)
1384+
kcent = KernelCenterer()
1385+
pipeline = Pipeline([("kernel_centerer", kcent), ("svr", SVR())])
1386+
1387+
# did the pipeline set the _pairwise attribute?
1388+
assert_true(pipeline._pairwise)
1389+
1390+
# test cross-validation, score should be almost perfect
1391+
score = cross_val_score(pipeline,K,y,cv=LeaveOneOut(4))
1392+
assert_array_almost_equal(score, np.ones_like(score))
1393+
13721394

13731395
def test_fit_transform():
13741396
rng = np.random.RandomState(0)

0 commit comments

Comments
 (0)