Skip to content

Commit 355dc7c

Browse files
committed
ENH improve parameter check in LogisticRegression
1 parent 0b07536 commit 355dc7c

File tree

2 files changed

+77
-54
lines changed

2 files changed

+77
-54
lines changed

sklearn/linear_model/logistic.py

Lines changed: 35 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,28 @@ def hessp(v):
391391
return grad, hessp
392392

393393

394+
def _check_solver_option(solver, multi_class, penalty, dual):
395+
if solver not in ['liblinear', 'newton-cg', 'lbfgs']:
396+
raise ValueError("Logistic Regression supports only liblinear,"
397+
" newton-cg and lbfgs solvers, got %s" % solver)
398+
399+
if multi_class not in ['multinomial', 'ovr']:
400+
raise ValueError("multi_class should be either multinomial or "
401+
"ovr, got %s" % multi_class)
402+
403+
if multi_class == 'multinomial' and solver == 'liblinear':
404+
raise ValueError("Solver %s does not support "
405+
"a multinomial backend." % solver)
406+
407+
if solver != 'liblinear':
408+
if penalty != 'l2':
409+
raise ValueError("Solver %s supports only l2 penalties, "
410+
"got %s penalty." % (solver, penalty))
411+
if dual:
412+
raise ValueError("Solver %s supports only "
413+
"dual=False, got dual=%s" % (solver, dual))
414+
415+
394416
def logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True,
395417
max_iter=100, tol=1e-4, verbose=0,
396418
solver='lbfgs', coef=None, copy=True,
@@ -501,25 +523,8 @@ def logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True,
501523
if isinstance(Cs, numbers.Integral):
502524
Cs = np.logspace(-4, 4, Cs)
503525

504-
if multi_class not in ['multinomial', 'ovr']:
505-
raise ValueError("multi_class can be either 'multinomial' or 'ovr'"
506-
"got %s" % multi_class)
507-
508-
if solver not in ['liblinear', 'newton-cg', 'lbfgs']:
509-
raise ValueError("Logistic Regression supports only liblinear,"
510-
" newton-cg and lbfgs solvers. got %s" % solver)
511-
512-
if multi_class == 'multinomial' and solver == 'liblinear':
513-
raise ValueError("Solver %s cannot solve problems with "
514-
"a multinomial backend." % solver)
526+
_check_solver_option(solver, multi_class, penalty, dual)
515527

516-
if solver != 'liblinear':
517-
if penalty != 'l2':
518-
raise ValueError("newton-cg and lbfgs solvers support only "
519-
"l2 penalties, got %s penalty." % penalty)
520-
if dual:
521-
raise ValueError("newton-cg and lbfgs solvers support only "
522-
"dual=False, got dual=%s" % dual)
523528
# Preprocessing.
524529
X = check_array(X, accept_sparse='csr', dtype=np.float64)
525530
y = check_array(y, ensure_2d=False, copy=copy, dtype=None)
@@ -781,6 +786,7 @@ def _log_reg_scoring_path(X, y, train, test, pos_class=None, Cs=10,
781786
scores : ndarray, shape (n_cs,)
782787
Scores obtained for each Cs.
783788
"""
789+
_check_solver_option(solver, multi_class, penalty, dual)
784790

785791
log_reg = LogisticRegression(fit_intercept=fit_intercept)
786792

@@ -1015,18 +1021,9 @@ def fit(self, X, y):
10151021

10161022
X, y = check_X_y(X, y, accept_sparse='csr', dtype=np.float64, order="C")
10171023
self.classes_ = np.unique(y)
1018-
if self.solver not in ['liblinear', 'newton-cg', 'lbfgs']:
1019-
raise ValueError(
1020-
"Logistic Regression supports only liblinear, newton-cg and "
1021-
"lbfgs solvers, Got solver=%s" % self.solver
1022-
)
10231024

1024-
if self.solver == 'liblinear' and self.multi_class == 'multinomial':
1025-
raise ValueError("Solver %s does not support a multinomial "
1026-
"backend." % self.solver)
1027-
if self.multi_class not in ['ovr', 'multinomial']:
1028-
raise ValueError("multi_class should be either ovr or multinomial "
1029-
"got %s" % self.multi_class)
1025+
_check_solver_option(self.solver, self.multi_class, self.penalty,
1026+
self.dual)
10301027

10311028
if self.solver == 'liblinear':
10321029
self.coef_, self.intercept_, self.n_iter_ = _fit_liblinear(
@@ -1308,22 +1305,19 @@ def fit(self, X, y):
13081305
self : object
13091306
Returns self.
13101307
"""
1311-
if self.solver != 'liblinear':
1312-
if self.penalty != 'l2':
1313-
raise ValueError("newton-cg and lbfgs solvers support only "
1314-
"l2 penalties.")
1315-
if self.dual:
1316-
raise ValueError("newton-cg and lbfgs solvers support only "
1317-
"the primal form.")
1308+
_check_solver_option(self.solver, self.multi_class, self.penalty,
1309+
self.dual)
1310+
1311+
if not isinstance(self.max_iter, numbers.Number) or self.max_iter < 0:
1312+
raise ValueError("Maximum number of iteration must be positive;"
1313+
" got (max_iter=%r)" % self.max_iter)
1314+
if not isinstance(self.tol, numbers.Number) or self.tol < 0:
1315+
raise ValueError("Tolerance for stopping criteria must be "
1316+
"positive; got (tol=%r)" % self.tol)
13181317

13191318
X = check_array(X, accept_sparse='csr', dtype=np.float64)
13201319
y = check_array(y, ensure_2d=False, dtype=None)
13211320

1322-
if self.multi_class not in ['ovr', 'multinomial']:
1323-
raise ValueError("multi_class backend should be either "
1324-
"'ovr' or 'multinomial'"
1325-
" got %s" % self.multi_class)
1326-
13271321
if y.ndim == 2 and y.shape[1] == 1:
13281322
warnings.warn(
13291323
"A column-vector y was passed when a 1d array was"

sklearn/linear_model/tests/test_logistic.py

Lines changed: 42 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from sklearn.utils.testing import assert_almost_equal
66
from sklearn.utils.testing import assert_array_equal
77
from sklearn.utils.testing import assert_array_almost_equal
8-
from sklearn.utils.testing import assert_raises_regexp
98
from sklearn.utils.testing import assert_equal
109
from sklearn.utils.testing import assert_greater
1110
from sklearn.utils.testing import assert_raises
@@ -69,22 +68,19 @@ def test_predict_2_classes():
6968
def test_error():
7069
# Test for appropriate exception on errors
7170
msg = "Penalty term must be positive"
72-
assert_raises_regexp(ValueError, msg,
71+
assert_raise_message(ValueError, msg,
7372
LogisticRegression(C=-1).fit, X, Y1)
74-
assert_raises_regexp(ValueError, msg,
73+
assert_raise_message(ValueError, msg,
7574
LogisticRegression(C="test").fit, X, Y1)
7675

77-
msg = "Tolerance for stopping criteria must be positive"
78-
assert_raises_regexp(ValueError, msg,
79-
LogisticRegression(tol=-1).fit, X, Y1)
80-
assert_raises_regexp(ValueError, msg,
81-
LogisticRegression(tol="test").fit, X, Y1)
76+
for LR in [LogisticRegression, LogisticRegressionCV]:
77+
msg = "Tolerance for stopping criteria must be positive"
78+
assert_raise_message(ValueError, msg, LR(tol=-1).fit, X, Y1)
79+
assert_raise_message(ValueError, msg, LR(tol="test").fit, X, Y1)
8280

83-
msg = "Maximum number of iteration must be positive"
84-
assert_raises_regexp(ValueError, msg,
85-
LogisticRegression(max_iter=-1).fit, X, Y1)
86-
assert_raises_regexp(ValueError, msg,
87-
LogisticRegression(max_iter="test").fit, X, Y1)
81+
msg = "Maximum number of iteration must be positive"
82+
assert_raise_message(ValueError, msg, LR(max_iter=-1).fit, X, Y1)
83+
assert_raise_message(ValueError, msg, LR(max_iter="test").fit, X, Y1)
8884

8985

9086
def test_predict_3_classes():
@@ -126,6 +122,39 @@ def test_multinomial_validation():
126122
assert_raises(ValueError, lr.fit, [[0, 1], [1, 0]], [0, 1])
127123

128124

125+
def test_check_solver_option():
    """Check that invalid solver option combinations raise ValueError.

    Each case pairs estimator constructor keyword arguments with the error
    message that ``fit`` is expected to raise. Every case is run against
    both LogisticRegression and LogisticRegressionCV.
    """
    X, y = iris.data, iris.target

    # (constructor kwargs, expected message), in the order they are checked.
    cases = [
        ({'solver': "wrong_name"},
         "Logistic Regression supports only liblinear, newton-cg and"
         " lbfgs solvers, got wrong_name"),
        ({'solver': 'newton-cg', 'multi_class': "wrong_name"},
         "multi_class should be either multinomial or ovr, got wrong_name"),
    ]

    # Solvers that have no multinomial backend.
    for name in ('liblinear',):
        cases.append(
            ({'solver': name, 'multi_class': 'multinomial'},
             "Solver %s does not support a multinomial backend." % name))

    # Solvers restricted to the l2 penalty and the primal formulation.
    for name in ('newton-cg', 'lbfgs'):
        cases.append(
            ({'solver': name, 'penalty': 'l1'},
             "Solver %s supports only l2 penalties, got l1 penalty." % name))
        cases.append(
            ({'solver': name, 'dual': True},
             "Solver %s supports only dual=False, got dual=True" % name))

    for estimator_class in (LogisticRegression, LogisticRegressionCV):
        for params, expected in cases:
            estimator = estimator_class(**params)
            assert_raise_message(ValueError, expected, estimator.fit, X, y)
156+
157+
129158
def test_multinomial_binary():
130159
# Test multinomial LR on a binary problem.
131160
target = (iris.target > 0).astype(np.intp)

0 commit comments

Comments
 (0)