Skip to content

Commit c4b650f

Browse files
committed
Merge pull request scikit-learn#4496 from vmichel/rfe_feature_importances
[MRG+1] Fix RFE
2 parents afd2b53 + 0f0e58d commit c4b650f

File tree

2 files changed

+34
-3
lines changed

2 files changed

+34
-3
lines changed

sklearn/feature_selection/rfe.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -160,10 +160,21 @@ def fit(self, X, y):
160160

161161
estimator.fit(X[:, features], y)
162162

163-
if estimator.coef_.ndim > 1:
164-
ranks = np.argsort(safe_sqr(estimator.coef_).sum(axis=0))
163+
# Get coefs
164+
if hasattr(estimator, 'coef_'):
165+
coefs = estimator.coef_
166+
elif hasattr(estimator, 'feature_importances_'):
167+
coefs = estimator.feature_importances_
165168
else:
166-
ranks = np.argsort(safe_sqr(estimator.coef_))
169+
raise RuntimeError('The classifier does not expose '
170+
'"coef_" or "feature_importances_" '
171+
'attributes')
172+
173+
# Get ranks
174+
if coefs.ndim > 1:
175+
ranks = np.argsort(safe_sqr(coefs).sum(axis=0))
176+
else:
177+
ranks = np.argsort(safe_sqr(coefs))
167178

168179
# in the sparse case, ranks is a matrix
169180
ranks = np.ravel(ranks)

sklearn/feature_selection/tests/test_rfe.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from sklearn.datasets import load_iris, make_friedman1
1212
from sklearn.metrics import zero_one_loss
1313
from sklearn.svm import SVC, SVR
14+
from sklearn.ensemble import RandomForestClassifier
1415

1516
from sklearn.utils import check_random_state
1617
from sklearn.utils.testing import ignore_warnings
@@ -69,6 +70,25 @@ def test_rfe_set_params():
6970
assert_array_equal(y_pred, y_pred2)
7071

7172

73+
def test_rfe_features_importance():
    """Check RFE with a ``feature_importances_`` estimator against a ``coef_`` one.

    RFE is fitted twice on the iris data extended with six columns of
    Gaussian noise: once wrapping a random forest and once wrapping a
    linear-kernel SVC. Both runs are asked for four features, and the test
    requires the two resulting support masks to be identical.
    """
    rng = check_random_state(0)
    iris = load_iris()
    n_samples = len(iris.data)

    # Append six noise columns to the original iris columns.
    noise = rng.normal(size=(n_samples, 6))
    X = np.c_[iris.data, noise]
    y = iris.target

    # NOTE(review): the forest is built without a random_state, so its
    # importances presumably vary from run to run -- confirm whether this
    # test should be seeded.
    forest_rfe = RFE(
        estimator=RandomForestClassifier(n_estimators=10, n_jobs=1),
        n_features_to_select=4,
        step=0.1,
    )
    forest_rfe.fit(X, y)
    # One rank per input column, including the noise columns.
    assert_equal(len(forest_rfe.ranking_), X.shape[1])

    svc_rfe = RFE(
        estimator=SVC(kernel="linear"),
        n_features_to_select=4,
        step=0.1,
    )
    svc_rfe.fit(X, y)

    # Both selectors must flag exactly the same columns.
    same_support = forest_rfe.get_support() == svc_rfe.get_support()
    assert_true(np.all(same_support))
91+
7292
def test_rfe_deprecation_estimator_params():
7393
deprecation_message = ("The parameter 'estimator_params' is deprecated as "
7494
"of version 0.16 and will be removed in 0.18. The "

0 commit comments

Comments
 (0)