|
11 | 11 | from sklearn.datasets import load_iris, make_friedman1 |
12 | 12 | from sklearn.metrics import zero_one_loss |
13 | 13 | from sklearn.svm import SVC, SVR |
| 14 | +from sklearn.ensemble import RandomForestClassifier |
14 | 15 |
|
15 | 16 | from sklearn.utils import check_random_state |
16 | 17 | from sklearn.utils.testing import ignore_warnings |
@@ -69,6 +70,25 @@ def test_rfe_set_params(): |
69 | 70 | assert_array_equal(y_pred, y_pred2) |
70 | 71 |
|
71 | 72 |
|
| 73 | +def test_rfe_features_importance(): |
| 74 | + generator = check_random_state(0) |
| 75 | + iris = load_iris() |
| 76 | + X = np.c_[iris.data, generator.normal(size=(len(iris.data), 6))] |
| 77 | + y = iris.target |
| 78 | + |
| 79 | + clf = RandomForestClassifier(n_estimators=10, n_jobs=1) |
| 80 | + rfe = RFE(estimator=clf, n_features_to_select=4, step=0.1) |
| 81 | + rfe.fit(X, y) |
| 82 | + assert_equal(len(rfe.ranking_), X.shape[1]) |
| 83 | + |
| 84 | + clf_svc = SVC(kernel="linear") |
| 85 | + rfe_svc = RFE(estimator=clf_svc, n_features_to_select=4, step=0.1) |
| 86 | + rfe_svc.fit(X, y) |
| 87 | + |
| 88 | + # Check if the supports are equal |
| 89 | + diff_support = rfe.get_support() == rfe_svc.get_support() |
| 90 | + assert_true(sum(diff_support) == len(diff_support)) |
| 91 | + |
72 | 92 | def test_rfe_deprecation_estimator_params(): |
73 | 93 | deprecation_message = ("The parameter 'estimator_params' is deprecated as " |
74 | 94 | "of version 0.16 and will be removed in 0.18. The " |
|
0 commit comments