Skip to content

Commit fc9d7be

Browse files
committed
FIX use random_state in LogisticRegression
1 parent 4eda9e6 commit fc9d7be

File tree

2 files changed: 24 additions (+24) and 8 deletions (-8).

sklearn/linear_model/logistic.py

Lines changed: 12 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,7 @@
2020
from ..preprocessing import LabelEncoder, LabelBinarizer
2121
from ..svm.base import _fit_liblinear
2222
from ..utils import check_array, check_consistent_length, compute_class_weight
23+
from ..utils import check_random_state
2324
from ..utils.extmath import (logsumexp, log_logistic, safe_sparse_dot,
2425
squared_norm)
2526
from ..utils.optimize import newton_cg
@@ -417,7 +418,8 @@ def logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True,
417418
max_iter=100, tol=1e-4, verbose=0,
418419
solver='lbfgs', coef=None, copy=True,
419420
class_weight=None, dual=False, penalty='l2',
420-
intercept_scaling=1., multi_class='ovr'):
421+
intercept_scaling=1., multi_class='ovr',
422+
random_state=None):
421423
"""Compute a Logistic Regression model for a list of regularization
422424
parameters.
423425
@@ -502,8 +504,12 @@ def logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True,
502504
Multiclass option can be either 'ovr' or 'multinomial'. If the option
503505
chosen is 'ovr', then a binary problem is fit for each label. Else
504506
the loss minimised is the multinomial loss fit across
505-
the entire probability distribution. Works only for the 'lbfgs'
506-
solver.
507+
the entire probability distribution. Works only for the 'lbfgs' and
508+
'newton-cg' solvers.
509+
510+
random_state : int seed, RandomState instance, or None (default)
511+
The seed of the pseudo random number generator to use when
512+
shuffling the data.
507513
508514
Returns
509515
-------
@@ -531,6 +537,7 @@ def logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True,
531537
_, n_features = X.shape
532538
check_consistent_length(X, y)
533539
classes = np.unique(y)
540+
random_state = check_random_state(random_state)
534541

535542
if pos_class is None and multi_class != 'multinomial':
536543
if (classes.size > 2):
@@ -659,7 +666,7 @@ def logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True,
659666
elif solver == 'liblinear':
660667
coef_, intercept_, _, = _fit_liblinear(
661668
X, y, C, fit_intercept, intercept_scaling, class_weight,
662-
penalty, dual, verbose, max_iter, tol,
669+
penalty, dual, verbose, max_iter, tol, random_state
663670
)
664671
if fit_intercept:
665672
w0 = np.concatenate([coef_.ravel(), intercept_])
@@ -1029,7 +1036,7 @@ def fit(self, X, y):
10291036
self.coef_, self.intercept_, self.n_iter_ = _fit_liblinear(
10301037
X, y, self.C, self.fit_intercept, self.intercept_scaling,
10311038
self.class_weight, self.penalty, self.dual, self.verbose,
1032-
self.max_iter, self.tol
1039+
self.max_iter, self.tol, self.random_state
10331040
)
10341041
return self
10351042

sklearn/linear_model/tests/test_logistic.py

Lines changed: 12 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -266,13 +266,22 @@ def test_consistency_path():
266266
assert_array_almost_equal(lr_coef, coefs[0], decimal=4)
267267

268268

269-
def test_liblinear_random_state():
269+
def test_liblinear_dual_random_state():
270+
# random_state is relevant for liblinear solver only if dual=True
270271
X, y = make_classification(n_samples=20)
271-
lr1 = LogisticRegression(random_state=0)
272+
lr1 = LogisticRegression(random_state=0, dual=True, max_iter=1, tol=1e-15)
272273
lr1.fit(X, y)
273-
lr2 = LogisticRegression(random_state=0)
274+
lr2 = LogisticRegression(random_state=0, dual=True, max_iter=1, tol=1e-15)
274275
lr2.fit(X, y)
276+
lr3 = LogisticRegression(random_state=8, dual=True, max_iter=1, tol=1e-15)
277+
lr3.fit(X, y)
278+
279+
# same result for same random state
275280
assert_array_almost_equal(lr1.coef_, lr2.coef_)
281+
# different results for different random states
282+
msg = "Arrays are not almost equal to 6 decimals"
283+
assert_raise_message(AssertionError, msg,
284+
assert_array_almost_equal, lr1.coef_, lr3.coef_)
276285

277286

278287
def test_logistic_loss_and_grad():

0 commit comments

Comments
 (0)