Skip to content

Commit 17ea0db

Browse files
x0lagramfort
authored andcommitted
warning -> error
though we can check if reg_param > 0 and then use spherical covariance, let's keep things simple
1 parent a72da02 commit 17ea0db

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

sklearn/qda.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -115,17 +115,16 @@ def fit(self, X, y, store_covariances=False, tol=1.0e-4):
115115
Xg = X[y == ind, :]
116116
meang = Xg.mean(0)
117117
means.append(meang)
118+
if len(Xg) == 1:
119+
raise ValueError('y has only 1 sample in class %s, covariance '
120+
'is ill defined.' % str(self.classes_[ind]))
118121
Xgc = Xg - meang
119122
# Xgc = U * S * V.T
120123
U, S, Vt = np.linalg.svd(Xgc, full_matrices=False)
121124
rank = np.sum(S > tol)
122125
if rank < n_features:
123126
warnings.warn("Variables are collinear")
124-
S2 = S ** 2
125-
if len(Xg) > 1:
126-
S2 /= len(Xg) - 1
127-
else:
128-
warnings.warn("Variance is null for one-element class")
127+
S2 = (S ** 2) / (len(Xg) - 1)
129128
S2 = ((1 - self.reg_param) * S2) + self.reg_param
130129
if store_covariances:
131130
# cov = V * (S^2 / (n-1)) * V.T

0 commit comments

Comments
 (0)