Skip to content

Commit 39f831d

Browse files
ndawearjoly
authored andcommitted
utils.testing: add assert_greater_equal and assert_less_equal
1 parent 72b3da8 commit 39f831d

File tree

2 files changed

+24
-6
lines changed

2 files changed

+24
-6
lines changed

sklearn/tree/tests/test_tree.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from sklearn.utils.testing import assert_in
1818
from sklearn.utils.testing import assert_raises
1919
from sklearn.utils.testing import assert_greater
20+
from sklearn.utils.testing import assert_greater_equal
2021
from sklearn.utils.testing import assert_less
2122
from sklearn.utils.testing import assert_true
2223
from sklearn.utils.testing import raises
@@ -481,11 +482,12 @@ def test_min_weight_fraction_leaf():
481482
node_weights = np.bincount(out, weights=weights)
482483
# drop inner nodes
483484
leaf_weights = node_weights[node_weights != 0]
484-
assert_true(np.min(leaf_weights) >=
485-
total_weight * est.min_weight_fraction_leaf,
486-
"Failed with {0} "
487-
"min_weight_fraction_leaf={1}".format(
488-
name, est.min_weight_fraction_leaf))
485+
assert_greater_equal(
486+
np.min(leaf_weights),
487+
total_weight * est.min_weight_fraction_leaf,
488+
"Failed with {0} "
489+
"min_weight_fraction_leaf={1}".format(
490+
name, est.min_weight_fraction_leaf))
489491

490492

491493
def test_pickle():

sklearn/utils/testing.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,9 @@
5252
__all__ = ["assert_equal", "assert_not_equal", "assert_raises",
5353
"assert_raises_regexp", "raises", "with_setup", "assert_true",
5454
"assert_false", "assert_almost_equal", "assert_array_equal",
55-
"assert_array_almost_equal", "assert_array_less"]
55+
"assert_array_almost_equal", "assert_array_less",
56+
"assert_less", "assert_less_equal",
57+
"assert_greater", "assert_greater_equal"]
5658

5759

5860
try:
@@ -103,6 +105,20 @@ def _assert_greater(a, b, msg=None):
103105
assert a > b, message
104106

105107

108+
def assert_less_equal(a, b, msg=None):
109+
message = "%r is not lower than or equal to %r" % (a, b)
110+
if msg is not None:
111+
message += ": " + msg
112+
assert a <= b, message
113+
114+
115+
def assert_greater_equal(a, b, msg=None):
116+
message = "%r is not greater than or equal to %r" % (a, b)
117+
if msg is not None:
118+
message += ": " + msg
119+
assert a >= b, message
120+
121+
106122
# To remove when we support numpy 1.7
107123
def assert_warns(warning_class, func, *args, **kw):
108124
"""Test that a certain warning occurs.

0 commit comments

Comments
 (0)