 from scipy.sparse import csr_matrix

 from sklearn.utils.testing import (assert_array_equal, assert_almost_equal,
-                                   assert_false, assert_raises, assert_equal)
+                                   assert_false, assert_raises, assert_equal,
+                                   assert_allclose, assert_greater)
 from sklearn.feature_selection.mutual_info_ import (
     mutual_info_regression, mutual_info_classif, _compute_mi)

@@ -158,8 +159,19 @@ def test_mutual_info_classif_mixed():
     y = ((0.5 * X[:, 0] + X[:, 2]) > 0.5).astype(int)
     X[:, 2] = X[:, 2] > 0.5

-    mi = mutual_info_classif(X, y, discrete_features=[2], random_state=0)
+    mi = mutual_info_classif(X, y, discrete_features=[2], n_neighbors=3,
+                             random_state=0)
     assert_array_equal(np.argsort(-mi), [2, 0, 1])
+    for n_neighbors in [5, 7, 9]:
+        mi_nn = mutual_info_classif(X, y, discrete_features=[2],
+                                    n_neighbors=n_neighbors, random_state=0)
+        # Check that the continuous features have a higher MI with greater
+        # n_neighbors.
+        assert_greater(mi_nn[0], mi[0])
+        assert_greater(mi_nn[1], mi[1])
+        # n_neighbors should not have any effect on the discrete feature:
+        # the MI should be the same.
+        assert_equal(mi_nn[2], mi[2])


 def test_mutual_info_options():
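Outside the patch itself, here is a minimal standalone sketch of the behaviour the new assertions check, written against the public `sklearn.feature_selection` API rather than the private `mutual_info_` module. The generation of `X` (sample count and distribution) is not shown in this hunk, so the setup below is an assumption that mirrors the visible lines; the script prints the comparison instead of asserting.

```python
import numpy as np
from sklearn.feature_selection import mutual_info_classif

# Assumed setup: uniform random features with the third column binarized,
# mirroring the test above (the exact construction of X is outside this hunk).
rng = np.random.RandomState(0)
X = rng.rand(1000, 3)
y = ((0.5 * X[:, 0] + X[:, 2]) > 0.5).astype(int)
X[:, 2] = X[:, 2] > 0.5

# Reference estimate with the default neighborhood size.
mi_ref = mutual_info_classif(X, y, discrete_features=[2], n_neighbors=3,
                             random_state=0)
for n_neighbors in [5, 7, 9]:
    mi_nn = mutual_info_classif(X, y, discrete_features=[2],
                                n_neighbors=n_neighbors, random_state=0)
    # Columns 0 and 1 are continuous: the k-NN based estimate tends to grow
    # with n_neighbors on this data.
    # Column 2 is discrete: n_neighbors plays no role there, so the estimate
    # should be identical to the reference.
    print(n_neighbors, mi_nn[:2] - mi_ref[:2], mi_nn[2] == mi_ref[2])
```

The discrete feature is unaffected because, with a discrete target, its MI is estimated from a contingency table rather than from the k-nearest-neighbor estimator, so `n_neighbors` never enters that computation.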