Skip to content

Commit a72da02

Browse files
x0lagramfort
authored andcommitted
warning + log of prod -> sum of log
1 parent 1c28711 commit a72da02

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

sklearn/qda.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,8 @@ def fit(self, X, y, store_covariances=False, tol=1.0e-4):
124124
S2 = S ** 2
125125
if len(Xg) > 1:
126126
S2 /= len(Xg) - 1
127+
else:
128+
warnings.warn("Variance is null for one-element class")
127129
S2 = ((1 - self.reg_param) * S2) + self.reg_param
128130
if store_covariances:
129131
# cov = V * (S^2 / (n-1)) * V.T
@@ -147,7 +149,7 @@ def _decision_function(self, X):
147149
X2 = np.dot(Xm, R * (S ** (-0.5)))
148150
norm2.append(np.sum(X2 ** 2, 1))
149151
norm2 = np.array(norm2).T # shape = [len(X), n_classes]
150-
u = np.log([np.prod(s) for s in self.scalings_])
152+
u = np.asarray([np.sum(np.log(s)) for s in self.scalings_])
151153
return (-0.5 * (norm2 + u) + np.log(self.priors_))
152154

153155
def decision_function(self, X):

0 commit comments

Comments
 (0)