Skip to content

Commit 8f72c2a

Browse files
authored
ENH Adds n_features_in_ checking in gaussian_process (scikit-learn#18743)
1 parent b7a48a0 commit 8f72c2a

File tree

3 files changed

+13
-9
lines changed

3 files changed

+13
-9
lines changed

sklearn/gaussian_process/_gpc.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from ..base import BaseEstimator, ClassifierMixin, clone
1515
from .kernels \
1616
import RBF, CompoundKernel, ConstantKernel as C
17-
from ..utils.validation import check_is_fitted, check_array
17+
from ..utils.validation import check_is_fitted
1818
from ..utils import check_random_state
1919
from ..utils.optimize import _check_optimize_result
2020
from ..preprocessing import LabelEncoder
@@ -689,9 +689,11 @@ def predict(self, X):
689689
check_is_fitted(self)
690690

691691
if self.kernel is None or self.kernel.requires_vector_input:
692-
X = check_array(X, ensure_2d=True, dtype="numeric")
692+
X = self._validate_data(X, ensure_2d=True, dtype="numeric",
693+
reset=False)
693694
else:
694-
X = check_array(X, ensure_2d=False, dtype=None)
695+
X = self._validate_data(X, ensure_2d=False, dtype=None,
696+
reset=False)
695697

696698
return self.base_estimator_.predict(X)
697699

@@ -717,9 +719,11 @@ def predict_proba(self, X):
717719
"one_vs_rest mode instead.")
718720

719721
if self.kernel is None or self.kernel.requires_vector_input:
720-
X = check_array(X, ensure_2d=True, dtype="numeric")
722+
X = self._validate_data(X, ensure_2d=True, dtype="numeric",
723+
reset=False)
721724
else:
722-
X = check_array(X, ensure_2d=False, dtype=None)
725+
X = self._validate_data(X, ensure_2d=False, dtype=None,
726+
reset=False)
723727

724728
return self.base_estimator_.predict_proba(X)
725729

sklearn/gaussian_process/_gpr.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
from ..base import MultiOutputMixin
1616
from .kernels import RBF, ConstantKernel as C
1717
from ..utils import check_random_state
18-
from ..utils.validation import check_array
1918
from ..utils.optimize import _check_optimize_result
2019
from ..utils.validation import _deprecate_positional_args
2120

@@ -320,9 +319,11 @@ def predict(self, X, return_std=False, return_cov=False):
320319
"returning full covariance.")
321320

322321
if self.kernel is None or self.kernel.requires_vector_input:
323-
X = check_array(X, ensure_2d=True, dtype="numeric")
322+
X = self._validate_data(X, ensure_2d=True, dtype="numeric",
323+
reset=False)
324324
else:
325-
X = check_array(X, ensure_2d=False, dtype=None)
325+
X = self._validate_data(X, ensure_2d=False, dtype=None,
326+
reset=False)
326327

327328
if not hasattr(self, "X_train_"): # Unfitted;predict based on GP prior
328329
if self.kernel is None:

sklearn/tests/test_common.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,6 @@ def test_search_cv(estimator, check, request):
272272
'ensemble',
273273
'feature_extraction',
274274
'feature_selection',
275-
'gaussian_process',
276275
'isotonic',
277276
'linear_model',
278277
'manifold',

0 commit comments

Comments
 (0)