Skip to content

Commit 36932d7

Browse files
x0lagramfort
authored andcommitted
tests
1 parent 17ea0db commit 36932d7

File tree

1 file changed

+17
-0
lines changed

1 file changed

+17
-0
lines changed

sklearn/tests/test_qda.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from sklearn.utils.testing import assert_array_almost_equal
55
from sklearn.utils.testing import assert_true
66
from sklearn.utils.testing import assert_greater
7+
from sklearn.utils.testing import assert_raises
78

89
from sklearn import qda
910

@@ -21,6 +22,13 @@
2122
X2 = np.array([[-3, 0], [-2, 0], [-1, 0], [-1, 0], [0, 0], [1, 0], [1, 0],
2223
[2, 0], [3, 0]])
2324

25+
# One element class
26+
y4 = np.array([1, 1, 1, 1, 1, 1, 1, 1, 2])
27+
28+
# Data with less samples in a class than n_features
29+
X5 = np.c_[np.arange(8), np.zeros((8,3))]
30+
y5 = np.array([0, 0, 0, 0, 0, 1, 1, 1])
31+
2432

2533
def test_qda():
2634
"""
@@ -47,6 +55,9 @@ def test_qda():
4755
# QDA shouldn't be able to separate those
4856
assert_true(np.any(y_pred3 != y3))
4957

58+
# Classes should have at least 2 elements
59+
assert_raises(ValueError, clf.fit, X, y4)
60+
5061

5162
def test_qda_priors():
5263
clf = qda.QDA()
@@ -92,3 +103,9 @@ def test_qda_regularization():
92103
clf = qda.QDA(reg_param=0.01)
93104
y_pred = clf.fit(X2, y).predict(X2)
94105
assert_array_equal(y_pred, y)
106+
107+
# Case n_samples_in_a_class < n_features
108+
# (needs some stronger regularization, test is very singular)
109+
clf = qda.QDA(reg_param=1e-1)
110+
y_pred5 = clf.fit(X5, y5).predict(X5)
111+
assert_array_equal(y_pred5, y5)

0 commit comments

Comments
 (0)