Commit d89c215

Add tags to classifiers and regressors to identify them as such.

1 parent e2dfd23 commit d89c215

16 files changed: +261 −61 lines

doc/developers/index.rst

Lines changed: 13 additions & 0 deletions

@@ -883,6 +883,19 @@ take arguments ``X, y``, even if y is not used. Similarly, for ``score`` to be
 usable, the last step of the pipeline needs to have a ``score`` function that
 accepts an optional ``y``.
 
+Estimator types
+---------------
+Some common functionality depends on the kind of estimator passed.
+For example, cross-validation in :class:`grid_search.GridSearchCV` and
+:func:`cross_validation.cross_val_score` defaults to being stratified when used
+on a classifier, but not otherwise. Similarly, scorers for average precision
+that take a continuous prediction need to call ``decision_function`` for classifiers,
+but ``predict`` for regressors. This distinction between classifiers and regressors
+is implemented using the ``_estimator_type`` attribute, which takes a string value.
+It should be ``"classifier"`` for classifiers and ``"regressor"`` for regressors,
+to work as expected. Inheriting from ``ClassifierMixin`` or ``RegressorMixin`` will
+set the attribute automatically.
+
 Working notes
 -------------
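The new section documents the tagging contract. As a minimal sketch (the estimator class here is hypothetical, not part of the commit), inheriting the mixin is all a downstream estimator needs to be recognized:

    from sklearn.base import BaseEstimator, ClassifierMixin, is_classifier

    class ToyClassifier(BaseEstimator, ClassifierMixin):
        """Toy estimator; inheriting ClassifierMixin sets _estimator_type."""

        def fit(self, X, y):
            return self

    print(ToyClassifier()._estimator_type)  # "classifier"
    print(is_classifier(ToyClassifier()))   # True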

sklearn/base.py

Lines changed: 12 additions & 16 deletions

@@ -244,14 +244,14 @@ def set_params(self, **params):
             if len(split) > 1:
                 # nested objects case
                 name, sub_name = split
-                if not name in valid_params:
+                if name not in valid_params:
                     raise ValueError('Invalid parameter %s for estimator %s' %
                                      (name, self))
                 sub_object = valid_params[name]
                 sub_object.set_params(**{sub_name: value})
             else:
                 # simple objects case
-                if not key in valid_params:
+                if key not in valid_params:
                     raise ValueError('Invalid parameter %s ' 'for estimator %s'
                                      % (key, self.__class__.__name__))
                 setattr(self, key, value)

@@ -266,6 +266,7 @@ def __repr__(self):
 ###############################################################################
 class ClassifierMixin(object):
     """Mixin class for all classifiers in scikit-learn."""
+    _estimator_type = "classifier"
 
     def score(self, X, y, sample_weight=None):
         """Returns the mean accuracy on the given test data and labels.

@@ -298,6 +299,7 @@ def score(self, X, y, sample_weight=None):
 ###############################################################################
 class RegressorMixin(object):
     """Mixin class for all regression estimators in scikit-learn."""
+    _estimator_type = "regressor"
 
     def score(self, X, y, sample_weight=None):
         """Returns the coefficient of determination R^2 of the prediction.

@@ -331,6 +333,8 @@ def score(self, X, y, sample_weight=None):
 ###############################################################################
 class ClusterMixin(object):
     """Mixin class for all cluster estimators in scikit-learn."""
+    _estimator_type = "clusterer"
+
     def fit_predict(self, X, y=None):
         """Performs clustering on X and returns cluster labels.
 
@@ -443,20 +447,12 @@ class MetaEstimatorMixin(object):
 
 
 ###############################################################################
-# XXX: Temporary solution to figure out if an estimator is a classifier
-
-def _get_sub_estimator(estimator):
-    """Returns the final estimator if there is any."""
-    if hasattr(estimator, 'estimator'):
-        # GridSearchCV and other CV-tuned estimators
-        return _get_sub_estimator(estimator.estimator)
-    if hasattr(estimator, 'steps'):
-        # Pipeline
-        return _get_sub_estimator(estimator.steps[-1][1])
-    return estimator
-
 
 def is_classifier(estimator):
     """Returns True if the given estimator is (probably) a classifier."""
-    estimator = _get_sub_estimator(estimator)
-    return isinstance(estimator, ClassifierMixin)
+    return getattr(estimator, "_estimator_type", None) == "classifier"
+
+
+def is_regressor(estimator):
+    """Returns True if the given estimator is (probably) a regressor."""
+    return getattr(estimator, "_estimator_type", None) == "regressor"
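The rewritten helpers duck-type on the string tag instead of recursively unwrapping meta-estimators and checking isinstance, so third-party classes need no scikit-learn base class at all. A small sketch of the new behavior (the class here is hypothetical):

    from sklearn.base import is_classifier, is_regressor

    class ThirdPartyClassifier(object):
        # No scikit-learn mixin required; the string tag alone is enough.
        _estimator_type = "classifier"

    est = ThirdPartyClassifier()
    print(is_classifier(est))       # True
    print(is_regressor(est))        # False
    print(is_classifier(object()))  # False: a missing tag reads as None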

sklearn/ensemble/gradient_boosting.py

Lines changed: 82 additions & 10 deletions

@@ -35,7 +35,7 @@
 from ..base import ClassifierMixin
 from ..base import RegressorMixin
 from ..utils import check_random_state, check_array, check_X_y, column_or_1d
-from ..utils import check_consistent_length
+from ..utils import check_consistent_length, deprecated
 from ..utils.extmath import logsumexp
 from ..utils.fixes import expit, bincount
 from ..utils.stats import _weighted_percentile

@@ -438,7 +438,7 @@ class ClassificationLossFunction(six.with_metaclass(ABCMeta, LossFunction)):
     def _score_to_proba(self, score):
         """Template method to convert scores to probabilities.
 
-        If the loss does not support probabilites raises AttributeError.
+        the does not support probabilites raises AttributeError.
         """
         raise TypeError('%s does not support predict_proba' % type(self).__name__)
 
@@ -1044,9 +1044,10 @@ def _fit_stages(self, X, y, y_pred, sample_weight, random_state,
                 self.train_score_[i] = loss_(y[sample_mask],
                                              y_pred[sample_mask],
                                              sample_weight[sample_mask])
-                self.oob_improvement_[i] = (old_oob_score -
-                                            loss_(y[~sample_mask], y_pred[~sample_mask],
-                                                  sample_weight[~sample_mask]))
+                self.oob_improvement_[i] = (
+                    old_oob_score - loss_(y[~sample_mask],
+                                          y_pred[~sample_mask],
+                                          sample_weight[~sample_mask]))
             else:
                 # no need to fancy index w/ no subsampling
                 self.train_score_[i] = loss_(y, y_pred, sample_weight)

@@ -1082,6 +1083,7 @@ def _decision_function(self, X):
         predict_stages(self.estimators_, X, self.learning_rate, score)
         return score
 
+    @deprecated(" and will be removed in 0.19")
     def decision_function(self, X):
         """Compute the decision function of ``X``.
 
@@ -1104,7 +1106,7 @@ def decision_function(self, X):
             return score.ravel()
         return score
 
-    def staged_decision_function(self, X):
+    def _staged_decision_function(self, X):
         """Compute decision function of ``X`` for each iteration.
 
         This method allows monitoring (i.e. determine error on testing set)

@@ -1129,6 +1131,30 @@ def staged_decision_function(self, X):
             predict_stage(self.estimators_, i, X, self.learning_rate, score)
             yield score.copy()
 
+    @deprecated(" and will be removed in 0.19")
+    def staged_decision_function(self, X):
+        """Compute decision function of ``X`` for each iteration.
+
+        This method allows monitoring (i.e. determine error on testing set)
+        after each stage.
+
+        Parameters
+        ----------
+        X : array-like of shape = [n_samples, n_features]
+            The input samples.
+
+        Returns
+        -------
+        score : generator of array, shape = [n_samples, k]
+            The decision function of the input samples. The order of the
+            classes corresponds to that in the attribute `classes_`.
+            Regression and binary classification are special cases with
+            ``k == 1``, otherwise ``k==n_classes``.
+        """
+        for dec in self._staged_decision_function(X):
+            # no yield from in Python2.X
+            yield dec
+
     @property
     def feature_importances_(self):
         """Return the feature importances (the higher, the more important the

@@ -1315,6 +1341,51 @@ def _validate_y(self, y):
         self.n_classes_ = len(self.classes_)
         return y
 
+    def decision_function(self, X):
+        """Compute the decision function of ``X``.
+
+        Parameters
+        ----------
+        X : array-like of shape = [n_samples, n_features]
+            The input samples.
+
+        Returns
+        -------
+        score : array, shape = [n_samples, n_classes] or [n_samples]
+            The decision function of the input samples. The order of the
+            classes corresponds to that in the attribute `classes_`.
+            Regression and binary classification produce an array of shape
+            [n_samples].
+        """
+        X = check_array(X, dtype=DTYPE, order="C")
+        score = self._decision_function(X)
+        if score.shape[1] == 1:
+            return score.ravel()
+        return score
+
+    def staged_decision_function(self, X):
+        """Compute decision function of ``X`` for each iteration.
+
+        This method allows monitoring (i.e. determine error on testing set)
+        after each stage.
+
+        Parameters
+        ----------
+        X : array-like of shape = [n_samples, n_features]
+            The input samples.
+
+        Returns
+        -------
+        score : generator of array, shape = [n_samples, k]
+            The decision function of the input samples. The order of the
+            classes corresponds to that in the attribute `classes_`.
+            Regression and binary classification are special cases with
+            ``k == 1``, otherwise ``k==n_classes``.
+        """
+        for dec in self._staged_decision_function(X):
+            # no yield from in Python2.X
+            yield dec
+
     def predict(self, X):
         """Predict class for X.
 
@@ -1348,7 +1419,7 @@ def staged_predict(self, X):
         y : generator of array of shape = [n_samples]
             The predicted value of the input samples.
         """
-        for score in self.staged_decision_function(X):
+        for score in self._staged_decision_function(X):
             decisions = self.loss_._score_to_decision(score)
             yield self.classes_.take(decisions, axis=0)
 
@@ -1419,7 +1490,7 @@ def staged_predict_proba(self, X):
             The predicted value of the input samples.
         """
         try:
-            for score in self.staged_decision_function(X):
+            for score in self._staged_decision_function(X):
                 yield self.loss_._score_to_proba(score)
         except NotFittedError:
             raise

@@ -1594,7 +1665,8 @@ def predict(self, X):
         y : array of shape = [n_samples]
             The predicted values.
         """
-        return self.decision_function(X).ravel()
+        X = check_array(X, dtype=DTYPE, order="C")
+        return self._decision_function(X).ravel()
 
     def staged_predict(self, X):
         """Predict regression target at each stage for X.

@@ -1612,5 +1684,5 @@ def staged_predict(self, X):
         y : generator of array of shape = [n_samples]
             The predicted value of the input samples.
         """
-        for y in self.staged_decision_function(X):
+        for y in self._staged_decision_function(X):
             yield y.ravel()
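The pattern throughout this file is to move the implementation into a private _-prefixed twin, re-point internal callers (predict, staged_predict, staged_predict_proba) at it, and keep the public name as a thin deprecated wrapper. A stripped-down sketch of that pattern (the class and method bodies are illustrative, not the commit's code):

    from sklearn.utils import deprecated

    class Model(object):
        def _decision_function(self, X):
            # Real implementation lives here; internal callers use this
            # directly, so they emit no deprecation warning.
            return [sum(row) for row in X]

        @deprecated(" and will be removed in 0.19")
        def decision_function(self, X):
            # Public entry point: warns once, then delegates.
            return self._decision_function(X)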

sklearn/ensemble/tests/test_gradient_boosting.py

Lines changed: 11 additions & 10 deletions

@@ -1,7 +1,7 @@
 """
 Testing for the gradient boosting module (sklearn.ensemble.gradient_boosting).
 """
-
+import warnings
 import numpy as np
 
 from sklearn import datasets

@@ -171,8 +171,9 @@ def test_boston():
     for loss in ("ls", "lad", "huber"):
         for subsample in (1.0, 0.5):
             last_y_pred = None
-            for i, sample_weight in enumerate((None, np.ones(len(boston.target)),
-                                               2 * np.ones(len(boston.target)))):
+            for i, sample_weight in enumerate(
+                    (None, np.ones(len(boston.target)),
+                     2 * np.ones(len(boston.target)))):
                 clf = GradientBoostingRegressor(n_estimators=100, loss=loss,
                                                 max_depth=4, subsample=subsample,
                                                 min_samples_split=1,

@@ -343,6 +344,7 @@ def test_check_max_features():
                                      max_features=-0.1)
     assert_raises(ValueError, clf.fit, X, y)
 
+
 def test_max_feature_regression():
     # Test to make sure random state is set properly.
     X, y = datasets.make_hastie_10_2(n_samples=12000, random_state=1)

@@ -455,7 +457,8 @@ def test_staged_functions_defensive():
         if staged_func is None:
             # regressor has no staged_predict_proba
            continue
-        staged_result = list(staged_func(X))
+        with warnings.catch_warnings(record=True):
+            staged_result = list(staged_func(X))
         staged_result[1][:] = 0
         assert_true(np.all(staged_result[0] != 0))
 
@@ -843,7 +846,7 @@ def test_complete_classification():
     k = 4
 
     est = GradientBoostingClassifier(n_estimators=20, max_depth=None,
-                                     random_state=1, max_leaf_nodes=k+1)
+                                     random_state=1, max_leaf_nodes=k + 1)
     est.fit(X, y)
 
     tree = est.estimators_[0, 0].tree_

@@ -858,7 +861,7 @@ def test_complete_regression():
     k = 4
 
     est = GradientBoostingRegressor(n_estimators=20, max_depth=None,
-                                    random_state=1, max_leaf_nodes=k+1)
+                                    random_state=1, max_leaf_nodes=k + 1)
     est.fit(boston.data, boston.target)
 
     tree = est.estimators_[-1, 0].tree_

@@ -971,8 +974,7 @@ def test_non_uniform_weights_toy_edge_case_reg():
     X = [[1, 0],
          [1, 0],
          [1, 0],
-         [0, 1],
-         ]
+         [0, 1]]
     y = [0, 0, 1, 0]
     # ignore the first 2 training samples by setting their weight to 0
     sample_weight = [0, 0, 1, 1]

@@ -1002,8 +1004,7 @@ def test_non_uniform_weights_toy_edge_case_clf():
     X = [[1, 0],
          [1, 0],
          [1, 0],
-         [0, 1],
-         ]
+         [0, 1]]
     y = [0, 0, 1, 0]
     # ignore the first 2 training samples by setting their weight to 0
     sample_weight = [0, 0, 1, 1]
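The catch_warnings(record=True) wrapper added in test_staged_functions_defensive keeps the now-deprecated public staged functions from failing the test or spamming output with DeprecationWarning noise. A generic sketch of the companion idiom for asserting that a deprecation actually fires (the helper name is hypothetical):

    import warnings

    def assert_warns_deprecated(func, *args, **kwargs):
        """Call func while recording warnings; check a DeprecationWarning fired."""
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            result = func(*args, **kwargs)
        assert any(issubclass(w.category, DeprecationWarning) for w in caught)
        return result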

sklearn/grid_search.py

Lines changed: 4 additions & 0 deletions

@@ -331,6 +331,10 @@ def __init__(self, estimator, scoring=None,
         self.pre_dispatch = pre_dispatch
         self.error_score = error_score
 
+    @property
+    def _estimator_type(self):
+        return self.estimator._estimator_type
+
     def score(self, X, y=None):
         """Returns the score on the given data, if the estimator has been refit
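With this property, a grid search advertises the same tag as the estimator it wraps, so is_classifier no longer needs to unwrap it recursively. A quick sketch of the effect (using the sklearn.grid_search module path of this era; later releases moved it to sklearn.model_selection):

    from sklearn.base import is_classifier
    from sklearn.grid_search import GridSearchCV
    from sklearn.svm import SVC

    search = GridSearchCV(SVC(), param_grid={"C": [0.1, 1, 10]})
    print(is_classifier(search))  # True: the property forwards SVC's tag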

sklearn/linear_model/base.py

Lines changed: 6 additions & 2 deletions

@@ -24,7 +24,7 @@
 from ..externals import six
 from ..externals.joblib import Parallel, delayed
 from ..base import BaseEstimator, ClassifierMixin, RegressorMixin
-from ..utils import as_float_array, check_array, check_X_y
+from ..utils import as_float_array, check_array, check_X_y, deprecated
 from ..utils.extmath import safe_sparse_dot
 from ..utils.sparsefuncs import mean_variance_axis, inplace_column_scale
 from ..utils.fixes import sparse_lsqr

@@ -119,6 +119,7 @@ class LinearModel(six.with_metaclass(ABCMeta, BaseEstimator)):
     def fit(self, X, y):
         """Fit model."""
 
+    @deprecated(" and will be removed in 0.19.")
     def decision_function(self, X):
         """Decision function of the linear model.
 
@@ -132,6 +133,9 @@ def decision_function(self, X):
         C : array, shape = (n_samples,)
             Returns predicted values.
         """
+        return self._decision_function(X)
+
+    def _decision_function(self, X):
         check_is_fitted(self, "coef_")
 
         X = check_array(X, accept_sparse=['csr', 'csc', 'coo'])

@@ -151,7 +155,7 @@ def predict(self, X):
         C : array, shape = (n_samples,)
             Returns predicted values.
         """
-        return self.decision_function(X)
+        return self._decision_function(X)
 
     _center_data = staticmethod(center_data)
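After this change, predict routes through the private _decision_function, so only explicit calls to the public decision_function trigger the deprecation. A sketch of the observable behavior on a linear model from this release (assuming the decorator issues a DeprecationWarning, as sklearn.utils.deprecated did at the time):

    import warnings
    import numpy as np
    from sklearn.linear_model import LinearRegression

    X = np.array([[0.0], [1.0], [2.0]])
    model = LinearRegression().fit(X, np.array([0.0, 1.0, 2.0]))

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        model.predict(X)            # silent: uses _decision_function
        model.decision_function(X)  # deprecated public path: warns
    print([w.category.__name__ for w in caught])  # ['DeprecationWarning']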
