Skip to content

Commit b1119bb

Browse files
committed
ENH friendlier message for calling predict before fit on SVMs
Fixes scikit-learn#3601.
1 parent 79646ff commit b1119bb

File tree

2 files changed

+16
-1
lines changed

2 files changed

+16
-1
lines changed

sklearn/svm/base.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,9 @@ def decision_function(self, X):
381381
return dec_func
382382

383383
def _validate_for_predict(self, X):
384+
if not hasattr(self, "support_"):
385+
raise ValueError("this %s has not been fitted yet"
386+
% type(self).__name__)
384387
X = check_array(X, accept_sparse='csr', dtype=np.float64, order="C")
385388
if self._sparse and not sp.isspmatrix(X):
386389
X = sp.csr_matrix(X)

sklearn/svm/tests/test_svm.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from sklearn.utils import check_random_state
1717
from sklearn.utils import ConvergenceWarning
1818
from sklearn.utils.testing import assert_greater, assert_in, assert_less
19-
from sklearn.utils.testing import assert_warns
19+
from sklearn.utils.testing import assert_raises_regexp, assert_warns
2020

2121

2222
# toy sample
@@ -664,6 +664,18 @@ def test_timeout():
664664
assert_warns(ConvergenceWarning, a.fit, X, Y)
665665

666666

667+
def test_unfitted():
668+
X = "foo!" # input validation not required when SVM not fitted
669+
670+
clf = svm.SVC()
671+
assert_raises_regexp(Exception, r".*\bSVC\b.*\bnot\b.*\bfitted\b",
672+
clf.predict, X)
673+
674+
clf = svm.NuSVR()
675+
assert_raises_regexp(Exception, r".*\bNuSVR\b.*\bnot\b.*\bfitted\b",
676+
clf.predict, X)
677+
678+
667679
def test_consistent_proba():
668680
a = svm.SVC(probability=True, max_iter=1, random_state=0)
669681
proba_1 = a.fit(X, Y).predict_proba(X)

0 commit comments

Comments
 (0)