Skip to content

Commit 28c893d

Browse files
committed
Merge pull request scikit-learn#4488 from glouppe/tree-apply
[MRG + 1] Public `apply` method for decision trees
2 parents 58dfb3d + 29dddc7 commit 28c893d

File tree

3 files changed

+66
-2
lines changed

3 files changed

+66
-2
lines changed

doc/whats_new.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@ Bug fixes
2525
API changes summary
2626
-------------------
2727

28+
- :class:`tree.DecisionTreeClassifier` now exposes an ``apply`` method
29+
for retrieving the leaf indices samples are predicted as. By
30+
`Daniel Galvez`_ and `Gilles Louppe`_.
31+
2832
.. _changes_0_16:
2933

3034
0.16

sklearn/tree/tests/test_tree.py

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from sklearn.utils.testing import assert_true
3030
from sklearn.utils.testing import raises
3131
from sklearn.utils.validation import check_random_state
32+
from sklearn.utils.validation import NotFittedError
3233

3334
from sklearn.tree import DecisionTreeClassifier
3435
from sklearn.tree import DecisionTreeRegressor
@@ -494,7 +495,7 @@ def test_error():
494495
for name, TreeEstimator in CLF_TREES.items():
495496
# predict before fit
496497
est = TreeEstimator()
497-
assert_raises(Exception, est.predict_proba, X)
498+
assert_raises(NotFittedError, est.predict_proba, X)
498499

499500
est.fit(X, y)
500501
X2 = [-2, -1, 1] # wrong feature shape for sample
@@ -527,7 +528,7 @@ def test_error():
527528

528529
# predict before fitting
529530
est = TreeEstimator()
530-
assert_raises(Exception, est.predict, T)
531+
assert_raises(NotFittedError, est.predict, T)
531532

532533
# predict on vector with different dims
533534
est.fit(X, y)
@@ -545,6 +546,10 @@ def test_error():
545546
clf.fit(X, y)
546547
assert_raises(ValueError, clf.predict, Xt)
547548

549+
# apply before fitting
550+
est = TreeEstimator()
551+
assert_raises(NotFittedError, est.apply, T)
552+
548553

549554
def test_min_samples_leaf():
550555
# Test if leaves contain more than leaf_count training examples
@@ -1208,6 +1213,8 @@ def check_explicit_sparse_zeros(tree, max_depth=3,
12081213
Xs = (X_test, X_sparse_test)
12091214
for X1, X2 in product(Xs, Xs):
12101215
assert_array_almost_equal(s.tree_.apply(X1), d.tree_.apply(X2))
1216+
assert_array_almost_equal(s.apply(X1), d.apply(X2))
1217+
assert_array_almost_equal(s.apply(X1), s.tree_.apply(X1))
12111218
assert_array_almost_equal(s.predict(X1), d.predict(X2))
12121219

12131220
if tree in CLF_TREES:
@@ -1266,3 +1273,29 @@ def check_min_weight_leaf_split_level(name):
12661273
def test_min_weight_leaf_split_level():
12671274
for name in ALL_TREES:
12681275
yield check_min_weight_leaf_split_level, name
1276+
1277+
1278+
def check_public_apply(name):
1279+
X_small32 = X_small.astype(tree._tree.DTYPE)
1280+
1281+
est = ALL_TREES[name]()
1282+
est.fit(X_small, y_small)
1283+
assert_array_equal(est.apply(X_small),
1284+
est.tree_.apply(X_small32))
1285+
1286+
1287+
def check_public_apply_sparse(name):
1288+
X_small32 = csr_matrix(X_small.astype(tree._tree.DTYPE))
1289+
1290+
est = ALL_TREES[name]()
1291+
est.fit(X_small, y_small)
1292+
assert_array_equal(est.apply(X_small),
1293+
est.tree_.apply(X_small32))
1294+
1295+
1296+
def test_public_apply():
1297+
for name in ALL_TREES:
1298+
yield (check_public_apply, name)
1299+
1300+
for name in SPARSE_TREES:
1301+
yield (check_public_apply_sparse, name)

sklearn/tree/tree.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,33 @@ def predict(self, X):
370370
else:
371371
return proba[:, :, 0]
372372

373+
def apply(self, X):
374+
"""
375+
Returns the index of the leaf that each sample is predicted as.
376+
377+
Parameters
378+
----------
379+
X : array_like or sparse matrix, shape = [n_samples, n_features]
380+
The input samples. Internally, it will be converted to
381+
``dtype=np.float32`` and if a sparse matrix is provided
382+
to a sparse ``csr_matrix``.
383+
384+
Returns
385+
-------
386+
X_leaves : array_like, shape = [n_samples,]
387+
For each datapoint x in X, return the index of the leaf x
388+
ends up in. Leaves are numbered within
389+
``[0; self.tree_.node_count)``, possibly with gaps in the
390+
numbering.
391+
"""
392+
if self.tree_ is None:
393+
raise NotFittedError("Estimator not fitted, "
394+
"call `fit` before `apply`.")
395+
396+
X = check_array(X, dtype=DTYPE, accept_sparse="csr")
397+
398+
return self.tree_.apply(X)
399+
373400
@property
374401
def feature_importances_(self):
375402
"""Return the feature importances.

0 commit comments

Comments
 (0)