Skip to content

Commit 922d71a

Browse files
committed
FIX check parameter in LogisticRegression
1 parent 5029db7 commit 922d71a

File tree

2 files changed

+25
-2
lines changed

2 files changed

+25
-2
lines changed

sklearn/linear_model/logistic.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1003,9 +1003,15 @@ def fit(self, X, y):
10031003
self : object
10041004
Returns self.
10051005
"""
1006-
if self.C < 0:
1006+
if not isinstance(self.C, numbers.Number) or self.C < 0:
10071007
raise ValueError("Penalty term must be positive; got (C=%r)"
10081008
% self.C)
1009+
if not isinstance(self.max_iter, numbers.Number) or self.max_iter < 0:
1010+
raise ValueError("Maximum number of iteration must be positive;"
1011+
" got (max_iter=%r)" % self.max_iter)
1012+
if not isinstance(self.tol, numbers.Number) or self.tol < 0:
1013+
raise ValueError("Tolerance for stopping criteria must be "
1014+
"positive; got (tol=%r)" % self.tol)
10091015

10101016
X, y = check_X_y(X, y, accept_sparse='csr', dtype=np.float64, order="C")
10111017
self.classes_ = np.unique(y)

sklearn/linear_model/tests/test_logistic.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from sklearn.utils.testing import assert_almost_equal
66
from sklearn.utils.testing import assert_array_equal
77
from sklearn.utils.testing import assert_array_almost_equal
8+
from sklearn.utils.testing import assert_raises_regexp
89
from sklearn.utils.testing import assert_equal
910
from sklearn.utils.testing import assert_greater
1011
from sklearn.utils.testing import assert_raises
@@ -67,7 +68,23 @@ def test_predict_2_classes():
6768

6869
def test_error():
6970
# Test for appropriate exception on errors
70-
assert_raises(ValueError, LogisticRegression(C=-1).fit, X, Y1)
71+
msg = "Penalty term must be positive"
72+
assert_raises_regexp(ValueError, msg,
73+
LogisticRegression(C=-1).fit, X, Y1)
74+
assert_raises_regexp(ValueError, msg,
75+
LogisticRegression(C="test").fit, X, Y1)
76+
77+
msg = "Tolerance for stopping criteria must be positive"
78+
assert_raises_regexp(ValueError, msg,
79+
LogisticRegression(tol=-1).fit, X, Y1)
80+
assert_raises_regexp(ValueError, msg,
81+
LogisticRegression(tol="test").fit, X, Y1)
82+
83+
msg = "Maximum number of iteration must be positive"
84+
assert_raises_regexp(ValueError, msg,
85+
LogisticRegression(max_iter=-1).fit, X, Y1)
86+
assert_raises_regexp(ValueError, msg,
87+
LogisticRegression(max_iter="test").fit, X, Y1)
7188

7289

7390
def test_predict_3_classes():

0 commit comments

Comments
 (0)