22import scipy .sparse as sp
33
44from numpy .testing import assert_array_equal
5+ from numpy .testing import assert_array_almost_equal
56import nose
67from nose .tools import assert_equal , assert_raises , raises
78
1718iris = datasets .load_iris ()
1819
1920
21+ def check_predictions (clf , X , y ):
22+ """Check that the model is able to fit the classification data"""
23+ n_samples = len (y )
24+ classes = np .unique (y )
25+ n_classes = classes .shape [0 ]
26+
27+ predicted = clf .fit (X , y ).predict (X )
28+ assert_array_equal (clf .classes_ , classes )
29+
30+ assert_equal (predicted .shape , (n_samples ,))
31+ assert_array_equal (predicted , y )
32+
33+ probabilities = clf .predict_proba (X )
34+ assert_equal (probabilities .shape , (n_samples , n_classes ))
35+ assert_array_almost_equal (probabilities .sum (axis = 1 ), np .ones (n_samples ))
36+ assert_array_equal (probabilities .argmax (axis = 1 ), y )
37+
38+
2039def test_predict_2_classes ():
2140 """Simple sanity check on a 2 classes dataset
2241
2342 Make sure it predicts the correct result on simple datasets.
2443 """
25- clf = logistic .LogisticRegression ().fit (X , Y1 )
26- assert_array_equal (clf .predict (X ), Y1 )
27- assert_array_equal (clf .predict_proba (X ).argmax (axis = 1 ), Y1 )
44+ check_predictions (logistic .LogisticRegression (), X , Y1 )
45+ check_predictions (logistic .LogisticRegression (), X_sp , Y1 )
2846
29- clf = logistic .LogisticRegression ().fit (X_sp , Y1 )
30- assert_array_equal (clf .predict (X_sp ), Y1 )
31- assert_array_equal (clf .predict_proba (X_sp ).argmax (axis = 1 ), Y1 )
47+ check_predictions (logistic .LogisticRegression (C = 100 ), X , Y1 )
48+ check_predictions (logistic .LogisticRegression (C = 100 ), X_sp , Y1 )
3249
33- clf = logistic .LogisticRegression (C = 100 ).fit (X , Y1 )
34- assert_array_equal (clf .predict (X ), Y1 )
35- assert_array_equal (clf .predict_proba (X ).argmax (axis = 1 ), Y1 )
36-
37- clf = logistic .LogisticRegression (C = 100 ).fit (X_sp , Y1 )
38- assert_array_equal (clf .predict (X_sp ), Y1 )
39- assert_array_equal (clf .predict_proba (X_sp ).argmax (axis = 1 ), Y1 )
40-
41- clf = logistic .LogisticRegression (fit_intercept = False ).fit (X , Y1 )
42- assert_array_equal (clf .predict (X ), Y1 )
43- assert_array_equal (clf .predict_proba (X ).argmax (axis = 1 ), Y1 )
44-
45- clf = logistic .LogisticRegression (fit_intercept = False ).fit (X_sp , Y1 )
46- assert_array_equal (clf .predict (X_sp ), Y1 )
47- assert_array_equal (clf .predict_proba (X_sp ).argmax (axis = 1 ), Y1 )
50+ check_predictions (logistic .LogisticRegression (fit_intercept = False ),
51+ X , Y1 )
52+ check_predictions (logistic .LogisticRegression (fit_intercept = False ),
53+ X_sp , Y1 )
4854
4955
5056def test_error ():
@@ -53,26 +59,25 @@ def test_error():
5359
5460
5561def test_predict_3_classes ():
56- clf = logistic .LogisticRegression (C = 10 ).fit (X , Y2 )
57- assert_array_equal (clf .predict (X ), Y2 )
58- assert_array_equal (clf .predict_proba (X ).argmax (axis = 1 ), Y2 )
59-
60- clf = logistic .LogisticRegression (C = 10 ).fit (X_sp , Y2 )
61- assert_array_equal (clf .predict (X_sp ), Y2 )
62- assert_array_equal (clf .predict_proba (X_sp ).argmax (axis = 1 ), Y2 )
62+ check_predictions (logistic .LogisticRegression (C = 10 ), X , Y2 )
63+ check_predictions (logistic .LogisticRegression (C = 10 ), X_sp , Y2 )
6364
6465
6566def test_predict_iris ():
6667 """Test logisic regression with the iris dataset"""
68+ n_samples , n_features = iris .data .shape
6769
6870 target = iris .target_names [iris .target ]
6971 clf = logistic .LogisticRegression (C = len (iris .data )).fit (iris .data , target )
70- assert_equal ( set (target ), set ( clf .classes_ ) )
72+ assert_array_equal ( np . unique (target ), clf .classes_ )
7173
7274 pred = clf .predict (iris .data )
7375 assert_greater (np .mean (pred == target ), .95 )
7476
75- pred = iris .target_names [clf .predict_proba (iris .data ).argmax (axis = 1 )]
77+ probabilities = clf .predict_proba (iris .data )
78+ assert_array_almost_equal (probabilities .sum (axis = 1 ), np .ones (n_samples ))
79+
80+ pred = iris .target_names [probabilities .argmax (axis = 1 )]
7681 assert_greater (np .mean (pred == target ), .95 )
7782
7883
0 commit comments