Skip to content

Commit 766fb56

Browse files
wlamondlarsmans
authored andcommitted
FIX allow ndim>2 in shuffle
Fixes scikit-learn#3694.
1 parent 7426e6a commit 766fb56

File tree

2 files changed

+8
-2
lines changed

2 files changed

+8
-2
lines changed

sklearn/utils/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -242,8 +242,8 @@ def resample(*arrays, **options):
242242
max_n_samples, n_samples))
243243

244244
check_consistent_length(*arrays)
245-
arrays = [check_array(x, accept_sparse='csr', ensure_2d=False)
246-
for x in arrays]
245+
arrays = [check_array(x, accept_sparse='csr', ensure_2d=False,
246+
allow_nd=True) for x in arrays]
247247

248248
if replace:
249249
indices = random_state.randint(0, n_samples, size=(max_n_samples,))

sklearn/utils/tests/test_utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from sklearn.utils import safe_mask
1515
from sklearn.utils import column_or_1d
1616
from sklearn.utils import safe_indexing
17+
from sklearn.utils import shuffle
1718
from sklearn.utils.extmath import pinvh
1819
from sklearn.utils.mocking import MockDataFrame
1920

@@ -175,3 +176,8 @@ def test_safe_indexing_mock_pandas():
175176
X_df_indexed = safe_indexing(X_df, inds)
176177
X_indexed = safe_indexing(X_df, inds)
177178
assert_array_equal(np.array(X_df_indexed), X_indexed)
179+
180+
181+
def test_shuffle_on_ndim_equals_three():
182+
A = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) # A.shape = (2,2,2)
183+
shuffle(A) # shouldn't raise a ValueError for dim = 3

0 commit comments

Comments
 (0)