Skip to content

Commit aaebee1

Browse files
glemaitre authored and raghavrv committed
FIX Issue scikit-learn#8173 - pass n_neighbors in MI computation (scikit-learn#8181)
1 parent 4826883 commit aaebee1

File tree

2 files changed

+15
-3
lines changed

2 files changed

+15
-3
lines changed

sklearn/feature_selection/mutual_info_.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,7 @@ def _estimate_mi(X, y, discrete_features='auto', discrete_target=False,
281281
y = scale(y, with_mean=False)
282282
y += 1e-10 * np.maximum(1, np.mean(np.abs(y))) * rng.randn(n_samples)
283283

284-
mi = [_compute_mi(x, y, discrete_feature, discrete_target) for
284+
mi = [_compute_mi(x, y, discrete_feature, discrete_target, n_neighbors) for
285285
x, discrete_feature in moves.zip(_iterate_columns(X), discrete_mask)]
286286

287287
return np.array(mi)

sklearn/feature_selection/tests/test_mutual_info.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
from scipy.sparse import csr_matrix
66

77
from sklearn.utils.testing import (assert_array_equal, assert_almost_equal,
8-
assert_false, assert_raises, assert_equal)
8+
assert_false, assert_raises, assert_equal,
9+
assert_allclose, assert_greater)
910
from sklearn.feature_selection.mutual_info_ import (
1011
mutual_info_regression, mutual_info_classif, _compute_mi)
1112

@@ -158,8 +159,19 @@ def test_mutual_info_classif_mixed():
158159
y = ((0.5 * X[:, 0] + X[:, 2]) > 0.5).astype(int)
159160
X[:, 2] = X[:, 2] > 0.5
160161

161-
mi = mutual_info_classif(X, y, discrete_features=[2], random_state=0)
162+
mi = mutual_info_classif(X, y, discrete_features=[2], n_neighbors=3,
163+
random_state=0)
162164
assert_array_equal(np.argsort(-mi), [2, 0, 1])
165+
for n_neighbors in [5, 7, 9]:
166+
mi_nn = mutual_info_classif(X, y, discrete_features=[2],
167+
n_neighbors=n_neighbors, random_state=0)
168+
# Check that the continuous values have a higher MI with greater
169+
# n_neighbors
170+
assert_greater(mi_nn[0], mi[0])
171+
assert_greater(mi_nn[1], mi[1])
172+
# The n_neighbors should not have any effect on the discrete value
173+
# The MI should be the same
174+
assert_equal(mi_nn[2], mi[2])
163175

164176

165177
def test_mutual_info_options():

0 commit comments

Comments
 (0)