Skip to content

Commit 719ca98

Browse files
committed
FIX SVR complaining about a single class in the input
Decoupled SVC and SVR input validation logic. Fixes scikit-learn#1896. BaseLibSVM could still do with some more refactoring to move subclass-specific code out of its methods.
1 parent 10ee464 commit 719ca98

File tree

2 files changed

+25
-13
lines changed

2 files changed

+25
-13
lines changed

sklearn/svm/base.py

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -131,19 +131,7 @@ def fit(self, X, y, sample_weight=None):
131131
"by not using the ``sparse`` parameter")
132132

133133
X = atleast2d_or_csr(X, dtype=np.float64, order='C')
134-
135-
if self._impl in ['c_svc', 'nu_svc']:
136-
# classification
137-
self.classes_, y = unique(y, return_inverse=True)
138-
self.class_weight_ = compute_class_weight(self.class_weight,
139-
self.classes_, y)
140-
else:
141-
self.class_weight_ = np.empty(0)
142-
if self._impl != "one_class" and len(np.unique(y)) < 2:
143-
raise ValueError("The number of classes has to be greater than"
144-
" one.")
145-
146-
y = np.asarray(y, dtype=np.float64, order='C')
134+
y = self._validate_targets(y)
147135

148136
sample_weight = np.asarray([]
149137
if sample_weight is None
@@ -190,6 +178,16 @@ def fit(self, X, y, sample_weight=None):
190178
self.intercept_ *= -1
191179
return self
192180

181+
def _validate_targets(self, y):
182+
"""Validation of y and class_weight.
183+
184+
Default implementation for SVR and one-class; overridden in BaseSVC.
185+
"""
186+
# XXX this is ugly.
187+
# Regression models should not have a class_weight_ attribute.
188+
self.class_weight_ = np.empty(0)
189+
return np.asarray(y, dtype=np.float64, order='C')
190+
193191
def _warn_from_fit_status(self):
194192
assert self.fit_status_ in (0, 1)
195193
if self.fit_status_ == 1:
@@ -434,6 +432,16 @@ def coef_(self):
434432
class BaseSVC(BaseLibSVM, ClassifierMixin):
435433
"""ABC for LibSVM-based classifiers."""
436434

435+
def _validate_targets(self, y):
436+
self.classes_, y = unique(y, return_inverse=True)
437+
self.class_weight_ = compute_class_weight(self.class_weight,
438+
self.classes_, y)
439+
if len(np.unique(y)) < 2:
440+
raise ValueError(
441+
"The number of classes has to be greater than one")
442+
443+
return np.asarray(y, dtype=np.float64, order='C')
444+
437445
def predict(self, X):
438446
"""Perform classification on samples in X.
439447

sklearn/svm/tests/test_svm.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,10 @@ def test_svr():
175175
clf.fit(diabetes.data, diabetes.target)
176176
assert_greater(clf.score(diabetes.data, diabetes.target), 0.02)
177177

178+
# non-regression test; previously, BaseLibSVM would check that
179+
# len(np.unique(y)) < 2, which must only be done for SVC
180+
svm.SVR().fit(diabetes.data, np.ones(len(diabetes.data)))
181+
178182

179183
def test_svr_errors():
180184
X = [[0.0], [1.0]]

0 commit comments

Comments
 (0)