Skip to content

Commit 1c07dec

Browse files
committed
Merge pull request scikit-learn#4894 from tw991/rf
fix dtype transform problem in KNN and RandomForest
2 parents 9f60b18 + 9f06156 commit 1c07dec

File tree

4 files changed

+26
-2
lines changed

4 files changed

+26
-2
lines changed

sklearn/ensemble/forest.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -414,10 +414,12 @@ def _validate_y_class_weight(self, y):
414414
self.classes_ = []
415415
self.n_classes_ = []
416416

417+
y_store_unique_indices = np.zeros(y.shape, dtype=np.int)
417418
for k in range(self.n_outputs_):
418-
classes_k, y[:, k] = np.unique(y[:, k], return_inverse=True)
419+
classes_k, y_store_unique_indices[:, k] = np.unique(y[:, k], return_inverse=True)
419420
self.classes_.append(classes_k)
420421
self.n_classes_.append(classes_k.shape[0])
422+
y = y_store_unique_indices
421423

422424
if self.class_weight is not None:
423425
valid_presets = ('auto', 'balanced', 'balanced_subsample', 'subsample', 'auto')

sklearn/ensemble/tests/test_forest.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -971,3 +971,13 @@ def test_warm_start_oob():
971971
yield check_warm_start_oob, name
972972
for name in FOREST_REGRESSORS:
973973
yield check_warm_start_oob, name
974+
975+
976+
def test_dtype_convert():
977+
classifier = RandomForestClassifier()
978+
CLASSES = 15
979+
X = np.eye(CLASSES)
980+
y = [ch for ch in 'ABCDEFGHIJKLMNOPQRSTU'[:CLASSES]]
981+
982+
result = classifier.fit(X, y).predict(X)
983+
assert_array_equal(result, y)

sklearn/neighbors/tests/test_neighbors.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1082,3 +1082,13 @@ def test_include_self_neighbors_graph():
10821082
X, 5.0, include_self=False).A
10831083
assert_array_equal(rng, [[1., 1.], [1., 1.]])
10841084
assert_array_equal(rng_not_self, [[0., 1.], [1., 0.]])
1085+
1086+
1087+
def test_dtype_convert():
1088+
classifier = neighbors.KNeighborsClassifier(n_neighbors=1)
1089+
CLASSES = 15
1090+
X = np.eye(CLASSES)
1091+
y = [ch for ch in 'ABCDEFGHIJKLMNOPQRSTU'[:CLASSES]]
1092+
1093+
result = classifier.fit(X, y).predict(X)
1094+
assert_array_equal(result, y)

sklearn/tree/tree.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,10 +166,12 @@ def fit(self, X, y, sample_weight=None, check_input=True):
166166
if self.class_weight is not None:
167167
y_original = np.copy(y)
168168

169+
y_store_unique_indices = np.zeros(y.shape, dtype=np.int)
169170
for k in range(self.n_outputs_):
170-
classes_k, y[:, k] = np.unique(y[:, k], return_inverse=True)
171+
classes_k, y_store_unique_indices[:, k] = np.unique(y[:, k], return_inverse=True)
171172
self.classes_.append(classes_k)
172173
self.n_classes_.append(classes_k.shape[0])
174+
y = y_store_unique_indices
173175

174176
if self.class_weight is not None:
175177
expanded_class_weight = compute_sample_weight(

0 commit comments

Comments
 (0)