Skip to content

Commit f656c37

Browse files
amy12xxrththomasjpfanjnothman
authored
FIX Fixes ClassifierChain issue when a tuple is passed to order (scikit-learn#18124)
* fix error tuple passed to order * fix linting * Update sklearn/tests/test_multioutput.py Co-authored-by: Roman Yurchak <[email protected]> * updated test and whatsnew * doc fix * Update doc/whats_new/v0.24.rst Co-authored-by: Roman Yurchak <[email protected]> * Update sklearn/tests/test_multioutput.py Co-authored-by: Thomas J. Fan <[email protected]> * Update sklearn/tests/test_multioutput.py Co-authored-by: Thomas J. Fan <[email protected]> * code review fix * code review fix * added test * code review fix * Update sklearn/multioutput.py Co-authored-by: Joel Nothman <[email protected]> Co-authored-by: Roman Yurchak <[email protected]> Co-authored-by: Thomas J. Fan <[email protected]> Co-authored-by: Joel Nothman <[email protected]>
1 parent e0abd26 commit f656c37

File tree

3 files changed

+37
-1
lines changed

3 files changed

+37
-1
lines changed

doc/whats_new/v0.24.rst

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,7 @@ Changelog
260260
will be removed in 0.28. :pr:`17662` by
261261
:user:`Joshua Newton <joshuacwnewton>`.
262262

263-
- |Efficiency| Fixed :issue:`10493`. Improve Local Linear Embedding (LLE)
263+
- |Efficiency| Fixed :issue:`10493`. Improve Local Linear Embedding (LLE)
264264
that raised `MemoryError` exception when used with large inputs.
265265
:pr:`17997` by :user:`Bertrand Maisonneuve <bmaisonn>`.
266266

@@ -323,6 +323,14 @@ Changelog
323323
validity of the input is now delegated to the base estimator.
324324
:pr:`17233` by :user:`Zolisa Bleki <zoj613>`.
325325

326+
:mod:`sklearn.multioutput`
327+
..........................
328+
329+
- |Fix| A fix to accept tuples for the ``order`` parameter
330+
in :class:`multioutput.ClassifierChain`.
331+
:pr:`18124` by :user:`Gus Brocchini <boldloop>` and
332+
:user:`Amanda Dsouza <amy12xx>`.
333+
326334
:mod:`sklearn.naive_bayes`
327335
..........................
328336

sklearn/multioutput.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -453,6 +453,9 @@ def fit(self, X, Y, **fit_params):
453453
random_state = check_random_state(self.random_state)
454454
check_array(X, accept_sparse=True)
455455
self.order_ = self.order
456+
if isinstance(self.order_, tuple):
457+
self.order_ = np.array(self.order_)
458+
456459
if self.order_ is None:
457460
self.order_ = np.array(range(Y.shape[1]))
458461
elif isinstance(self.order_, str):

sklearn/tests/test_multioutput.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -601,3 +601,28 @@ def fit(self, X, y, **fit_params):
601601

602602
for est in model.estimators_:
603603
assert est.sample_weight_ is weight
604+
605+
606+
@pytest.mark.parametrize("order_type", [list, np.array, tuple])
607+
def test_classifier_chain_tuple_order(order_type):
608+
X = [[1, 2, 3], [4, 5, 6], [1.5, 2.5, 3.5]]
609+
y = [[3, 2], [2, 3], [3, 2]]
610+
order = order_type([1, 0])
611+
612+
chain = ClassifierChain(RandomForestClassifier(), order=order)
613+
614+
chain.fit(X, y)
615+
X_test = [[1.5, 2.5, 3.5]]
616+
y_test = [[3, 2]]
617+
assert_array_almost_equal(chain.predict(X_test), y_test)
618+
619+
620+
def test_classifier_chain_tuple_invalid_order():
621+
X = [[1, 2, 3], [4, 5, 6], [1.5, 2.5, 3.5]]
622+
y = [[3, 2], [2, 3], [3, 2]]
623+
order = tuple([1, 2])
624+
625+
chain = ClassifierChain(RandomForestClassifier(), order=order)
626+
627+
with pytest.raises(ValueError, match='invalid order'):
628+
chain.fit(X, y)

0 commit comments

Comments
 (0)