Skip to content

Commit 31f412a

Browse files
committed
COSMIT decouple regression and classification in SVMs
1 parent bae2d1b commit 31f412a

File tree

1 file changed

+20
-11
lines changed

1 file changed

+20
-11
lines changed

sklearn/svm/base.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def fit(self, X, y, sample_weight=None):
102102
and n_features is the number of features.
103103
104104
y : array-like, shape = [n_samples]
105-
Target values (integers in classification, real numbers in
105+
Target values (class labels in classification, real numbers in
106106
regression)
107107
108108
sample_weight : array-like, shape = [n_samples], optional
@@ -259,11 +259,7 @@ def _sparse_fit(self, X, y, sample_weight, solver_type, kernel):
259259
(n_class, n_SV))
260260

261261
def predict(self, X):
262-
"""Perform classification or regression samples in X.
263-
264-
For a classification model, the predicted class for each
265-
sample in X is returned. For a regression model, the function
266-
value of X calculated is returned.
262+
"""Perform regression on samples in X.
267263
268264
For an one-class model, +1 or -1 is returned.
269265
@@ -277,11 +273,7 @@ def predict(self, X):
277273
"""
278274
X = self._validate_for_predict(X)
279275
predict = self._sparse_predict if self._sparse else self._dense_predict
280-
y = predict(X)
281-
if self.impl in ['c_svc', 'nu_svc']:
282-
# classification
283-
y = self.classes_.take(y.astype(np.int))
284-
return y
276+
return predict(X)
285277

286278
def _dense_predict(self, X):
287279
n_samples, n_features = X.shape
@@ -445,6 +437,23 @@ def coef_(self):
445437
class BaseSVC(BaseLibSVM, ClassifierMixin):
446438
"""ABC for LibSVM-based classifiers."""
447439

440+
def predict(self, X):
441+
"""Perform classification on samples in X.
442+
443+
For an one-class model, +1 or -1 is returned.
444+
445+
Parameters
446+
----------
447+
X : {array-like, sparse matrix}, shape = [n_samples, n_features]
448+
449+
Returns
450+
-------
451+
y_pred : array, shape = [n_samples]
452+
Class labels for samples in X.
453+
"""
454+
y = super(BaseSVC, self).predict(X)
455+
return self.classes_.take(y.astype(np.int))
456+
448457
def predict_proba(self, X):
449458
"""Compute probabilities of possible outcomes for samples in X.
450459

0 commit comments

Comments
 (0)