44from sklearn .utils .testing import assert_array_almost_equal
55from sklearn .utils .testing import assert_true
66from sklearn .utils .testing import assert_greater
7+ from sklearn .utils .testing import assert_raises
78
89from sklearn import qda
910
2122X2 = np .array ([[- 3 , 0 ], [- 2 , 0 ], [- 1 , 0 ], [- 1 , 0 ], [0 , 0 ], [1 , 0 ], [1 , 0 ],
2223 [2 , 0 ], [3 , 0 ]])
2324
25+ # One element class
26+ y4 = np .array ([1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 2 ])
27+
28+ # Data with less samples in a class than n_features
29+ X5 = np .c_ [np .arange (8 ), np .zeros ((8 ,3 ))]
30+ y5 = np .array ([0 , 0 , 0 , 0 , 0 , 1 , 1 , 1 ])
31+
2432
2533def test_qda ():
2634 """
@@ -47,6 +55,9 @@ def test_qda():
4755 # QDA shouldn't be able to separate those
4856 assert_true (np .any (y_pred3 != y3 ))
4957
58+ # Classes should have at least 2 elements
59+ assert_raises (ValueError , clf .fit , X , y4 )
60+
5061
5162def test_qda_priors ():
5263 clf = qda .QDA ()
@@ -92,3 +103,9 @@ def test_qda_regularization():
92103 clf = qda .QDA (reg_param = 0.01 )
93104 y_pred = clf .fit (X2 , y ).predict (X2 )
94105 assert_array_equal (y_pred , y )
106+
107+ # Case n_samples_in_a_class < n_features
108+ # (needs some stronger regularization, test is very singular)
109+ clf = qda .QDA (reg_param = 1e-1 )
110+ y_pred5 = clf .fit (X5 , y5 ).predict (X5 )
111+ assert_array_equal (y_pred5 , y5 )
0 commit comments