Skip to content

Commit 244e8a6

Browse files
committed
Merge pull request scikit-learn#1187 from ogrisel/bugfix-logistic-ovr-probabilities
MRG: FIX: wrong probabilities for OvR LogisticRegression
2 parents aa3641f + 531dcb0 commit 244e8a6

File tree

3 files changed

+58
-37
lines changed

sklearn/linear_model/logistic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ def predict_proba(self, X):
122122
return np.vstack([1 - prob, prob]).T
123123
else:
124124
# OvR, not softmax, like Liblinear's predict_probability
125-
prob /= prob.sum(axis=0)
125+
prob /= prob.sum(axis=1).reshape((prob.shape[0], -1))
126126
return prob
127127

128128
def predict_log_proba(self, X):

sklearn/linear_model/tests/test_logistic.py

Lines changed: 35 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import scipy.sparse as sp
33

44
from numpy.testing import assert_array_equal
5+
from numpy.testing import assert_array_almost_equal
56
import nose
67
from nose.tools import assert_equal, assert_raises, raises
78

@@ -17,34 +18,39 @@
1718
iris = datasets.load_iris()
1819

1920

21+
def check_predictions(clf, X, y):
22+
"""Check that the model is able to fit the classification data"""
23+
n_samples = len(y)
24+
classes = np.unique(y)
25+
n_classes = classes.shape[0]
26+
27+
predicted = clf.fit(X, y).predict(X)
28+
assert_array_equal(clf.classes_, classes)
29+
30+
assert_equal(predicted.shape, (n_samples,))
31+
assert_array_equal(predicted, y)
32+
33+
probabilities = clf.predict_proba(X)
34+
assert_equal(probabilities.shape, (n_samples, n_classes))
35+
assert_array_almost_equal(probabilities.sum(axis=1), np.ones(n_samples))
36+
assert_array_equal(probabilities.argmax(axis=1), y)
37+
38+
2039
def test_predict_2_classes():
2140
"""Simple sanity check on a 2 classes dataset
2241
2342
Make sure it predicts the correct result on simple datasets.
2443
"""
25-
clf = logistic.LogisticRegression().fit(X, Y1)
26-
assert_array_equal(clf.predict(X), Y1)
27-
assert_array_equal(clf.predict_proba(X).argmax(axis=1), Y1)
44+
check_predictions(logistic.LogisticRegression(), X, Y1)
45+
check_predictions(logistic.LogisticRegression(), X_sp, Y1)
2846

29-
clf = logistic.LogisticRegression().fit(X_sp, Y1)
30-
assert_array_equal(clf.predict(X_sp), Y1)
31-
assert_array_equal(clf.predict_proba(X_sp).argmax(axis=1), Y1)
47+
check_predictions(logistic.LogisticRegression(C=100), X, Y1)
48+
check_predictions(logistic.LogisticRegression(C=100), X_sp, Y1)
3249

33-
clf = logistic.LogisticRegression(C=100).fit(X, Y1)
34-
assert_array_equal(clf.predict(X), Y1)
35-
assert_array_equal(clf.predict_proba(X).argmax(axis=1), Y1)
36-
37-
clf = logistic.LogisticRegression(C=100).fit(X_sp, Y1)
38-
assert_array_equal(clf.predict(X_sp), Y1)
39-
assert_array_equal(clf.predict_proba(X_sp).argmax(axis=1), Y1)
40-
41-
clf = logistic.LogisticRegression(fit_intercept=False).fit(X, Y1)
42-
assert_array_equal(clf.predict(X), Y1)
43-
assert_array_equal(clf.predict_proba(X).argmax(axis=1), Y1)
44-
45-
clf = logistic.LogisticRegression(fit_intercept=False).fit(X_sp, Y1)
46-
assert_array_equal(clf.predict(X_sp), Y1)
47-
assert_array_equal(clf.predict_proba(X_sp).argmax(axis=1), Y1)
50+
check_predictions(logistic.LogisticRegression(fit_intercept=False),
51+
X, Y1)
52+
check_predictions(logistic.LogisticRegression(fit_intercept=False),
53+
X_sp, Y1)
4854

4955

5056
def test_error():
@@ -53,26 +59,25 @@ def test_error():
5359

5460

5561
def test_predict_3_classes():
56-
clf = logistic.LogisticRegression(C=10).fit(X, Y2)
57-
assert_array_equal(clf.predict(X), Y2)
58-
assert_array_equal(clf.predict_proba(X).argmax(axis=1), Y2)
59-
60-
clf = logistic.LogisticRegression(C=10).fit(X_sp, Y2)
61-
assert_array_equal(clf.predict(X_sp), Y2)
62-
assert_array_equal(clf.predict_proba(X_sp).argmax(axis=1), Y2)
62+
check_predictions(logistic.LogisticRegression(C=10), X, Y2)
63+
check_predictions(logistic.LogisticRegression(C=10), X_sp, Y2)
6364

6465

6566
def test_predict_iris():
6667
"""Test logistic regression with the iris dataset"""
68+
n_samples, n_features = iris.data.shape
6769

6870
target = iris.target_names[iris.target]
6971
clf = logistic.LogisticRegression(C=len(iris.data)).fit(iris.data, target)
70-
assert_equal(set(target), set(clf.classes_))
72+
assert_array_equal(np.unique(target), clf.classes_)
7173

7274
pred = clf.predict(iris.data)
7375
assert_greater(np.mean(pred == target), .95)
7476

75-
pred = iris.target_names[clf.predict_proba(iris.data).argmax(axis=1)]
77+
probabilities = clf.predict_proba(iris.data)
78+
assert_array_almost_equal(probabilities.sum(axis=1), np.ones(n_samples))
79+
80+
pred = iris.target_names[probabilities.argmax(axis=1)]
7681
assert_greater(np.mean(pred == target), .95)
7782

7883

sklearn/tests/test_common.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -314,7 +314,8 @@ def test_classifiers_train():
314314
X_b = X_m[y_m != 2]
315315
for (X, y) in [(X_m, y_m), (X_b, y_b)]:
316316
# do it once with binary, once with multiclass
317-
n_labels = len(np.unique(y))
317+
classes = np.unique(y)
318+
n_classes = len(classes)
318319
n_samples, n_features = X.shape
319320
for name, Clf in classifiers:
320321
if Clf in dont_test or Clf in meta_estimators:
@@ -341,13 +342,13 @@ def test_classifiers_train():
341342
try:
342343
# decision_function agrees with predict:
343344
decision = clf.decision_function(X)
344-
if n_labels is 2:
345+
if n_classes is 2:
345346
assert_equal(decision.ravel().shape, (n_samples,))
346347
dec_pred = (decision.ravel() > 0).astype(np.int)
347348
assert_array_equal(dec_pred, y_pred)
348-
if n_labels is 3 and not isinstance(clf, BaseLibSVM):
349+
if n_classes is 3 and not isinstance(clf, BaseLibSVM):
349350
# 1on1 of LibSVM works differently
350-
assert_equal(decision.shape, (n_samples, n_labels))
351+
assert_equal(decision.shape, (n_samples, n_classes))
351352
assert_array_equal(np.argmax(decision, axis=1), y_pred)
352353

353354
# raises error on malformed input
@@ -360,15 +361,30 @@ def test_classifiers_train():
360361
try:
361362
# predict_proba agrees with predict:
362363
y_prob = clf.predict_proba(X)
363-
assert_equal(y_prob.shape, (n_samples, n_labels))
364+
assert_equal(y_prob.shape, (n_samples, n_classes))
365+
assert_array_equal(np.argmax(y_prob, axis=1), y_pred)
366+
# check that probas for all classes sum to one
367+
assert_array_almost_equal(
368+
np.sum(y_prob, axis=1), np.ones(n_samples))
364369
# raises error on malformed input
365370
assert_raises(ValueError, clf.predict_proba, X.T)
366-
assert_array_equal(np.argmax(y_prob, axis=1), y_pred)
367371
# raises error on malformed input for predict_proba
368372
assert_raises(ValueError, clf.predict_proba, X.T)
369373
except NotImplementedError:
370374
pass
371375

376+
if hasattr(clf, "classes_"):
377+
if hasattr(clf, "n_outputs_"):
378+
assert_equal(clf.n_outputs_, 1)
379+
assert_array_equal(
380+
clf.classes_, [classes],
381+
"Unexpected classes_ attribute for %r" % clf)
382+
else:
383+
# flat classes array: XXX inconsistent
384+
assert_array_equal(
385+
clf.classes_, classes,
386+
"Unexpected classes_ attribute for %r" % clf)
387+
372388

373389
def test_classifiers_classes():
374390
# test if classifiers can cope with non-consecutive classes

0 commit comments

Comments
 (0)