Skip to content

Commit 6ce198c

Browse files
facaiyagramfort
authored andcommitted
[MRG] scikit-learn#6581 n_samples of utils.resample can be more when replace is True
* scikit-learn#6581 n_samples can be more when replace is True * more compact code
1 parent 33396b1 commit 6ce198c

File tree

2 files changed

+14
-11
lines changed

2 files changed

+14
-11
lines changed

sklearn/utils/__init__.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,8 @@ def resample(*arrays, **options):
133133
n_samples : int, None by default
134134
Number of samples to generate. If left to None this is
135135
automatically set to the first dimension of the arrays.
136+
If replace is False it should not be larger than the length of
137+
arrays.
136138
137139
random_state : int or RandomState instance
138140
Control the shuffling for reproducible behavior.
@@ -194,10 +196,10 @@ def resample(*arrays, **options):
194196

195197
if max_n_samples is None:
196198
max_n_samples = n_samples
197-
198-
if max_n_samples > n_samples:
199-
raise ValueError("Cannot sample %d out of arrays with dim %d" % (
200-
max_n_samples, n_samples))
199+
elif (max_n_samples > n_samples) and (not replace):
200+
raise ValueError("Cannot sample %d out of arrays with dim %d"
201+
"when replace is False" % (max_n_samples,
202+
n_samples))
201203

202204
check_consistent_length(*arrays)
203205

sklearn/utils/tests/test_utils.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,6 @@ def test_make_rng():
4242
assert_raises(ValueError, check_random_state, "some invalid seed")
4343

4444

45-
def test_resample_noarg():
46-
# Border case not worth mentioning in doctests
47-
assert_true(resample() is None)
48-
49-
5045
def test_deprecated():
5146
# Test whether the deprecated decorator issues appropriate warnings
5247
# Copied almost verbatim from http://docs.python.org/library/warnings.html
@@ -84,11 +79,17 @@ class Ham(object):
8479
assert_true("deprecated" in str(w[0].message).lower())
8580

8681

87-
def test_resample_value_errors():
82+
def test_resample():
83+
# Border case not worth mentioning in doctests
84+
assert_true(resample() is None)
85+
8886
# Check that invalid arguments yield ValueError
8987
assert_raises(ValueError, resample, [0], [0, 1])
90-
assert_raises(ValueError, resample, [0, 1], [0, 1], n_samples=3)
88+
assert_raises(ValueError, resample, [0, 1], [0, 1],
89+
replace=False, n_samples=3)
9190
assert_raises(ValueError, resample, [0, 1], [0, 1], meaning_of_life=42)
91+
# Issue:6581, n_samples can be more when replace is True (default).
92+
assert_equal(len(resample([1, 2], n_samples=5)), 5)
9293

9394

9495
def test_safe_mask():

0 commit comments

Comments
 (0)