Skip to content

Commit 154007b

Browse files
committed
fixing bug in linear_model.SGDClassifier for multi-class warm start
1 parent e3dac0d commit 154007b

File tree

2 files changed

+9
-5
lines changed

2 files changed

+9
-5
lines changed

sklearn/linear_model/stochastic_gradient.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,7 @@ def fit_binary(est, i, X, y, alpha, C, learning_rate, n_iter,
268268
The i'th class is considered the "positive" class.
269269
"""
270270
y_i, coef, intercept = _prepare_fit_binary(est, y, i)
271+
271272
assert y_i.shape[0] == y.shape[0] == sample_weight.shape[0]
272273
dataset, intercept_decay = _make_dataset(X, y_i, sample_weight)
273274

@@ -361,7 +362,7 @@ def _partial_fit(self, X, y, alpha, C,
361362
self.classes_, y_ind)
362363
sample_weight = self._validate_sample_weight(sample_weight, n_samples)
363364

364-
if self.coef_ is None:
365+
if self.coef_ is None or coef_init is not None:
365366
self._allocate_parameter_mem(n_classes, n_features,
366367
coef_init, intercept_init)
367368

sklearn/linear_model/tests/test_sgd.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def decision_function(self, X, *args, **kw):
104104

105105
class CommonTest(object):
106106

107-
def _test_warm_start(self, lr):
107+
def _test_warm_start(self, X, Y, lr):
108108
# Test that explicit warm restart...
109109
clf = self.factory(alpha=0.01, eta0=0.01, n_iter=5, shuffle=False,
110110
learning_rate=lr)
@@ -131,13 +131,16 @@ def _test_warm_start(self, lr):
131131
assert_array_almost_equal(clf3.coef_, clf2.coef_)
132132

133133
def test_warm_start_constant(self):
134-
self._test_warm_start("constant")
134+
self._test_warm_start(X, Y, "constant")
135135

136136
def test_warm_start_invscaling(self):
137-
self._test_warm_start("invscaling")
137+
self._test_warm_start(X, Y, "invscaling")
138138

139139
def test_warm_start_optimal(self):
140-
self._test_warm_start("optimal")
140+
self._test_warm_start(X, Y, "optimal")
141+
142+
def test_warm_start_multiclass(self):
143+
self._test_warm_start(X2, Y2, "optimal")
141144

142145
def test_multiple_fit(self):
143146
"""Test multiple calls of fit w/ different shaped inputs."""

0 commit comments

Comments
 (0)