Skip to content

Commit a7b2f63

Browse files
committed
unittest: Add discover function.
1 parent 9d9ca3d commit a7b2f63

File tree

4 files changed

+141
-15
lines changed

4 files changed

+141
-15
lines changed

python-stdlib/unittest/metadata.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
srctype = micropython-lib
22
type = module
33
version = 0.3.2
4+
depends = argparse, fnmatch

python-stdlib/unittest/setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,4 +21,5 @@
2121
license="MIT",
2222
cmdclass={"sdist": sdist_upip.sdist},
2323
py_modules=["unittest"],
24+
install_requires=["micropython-argparse", "micropython-fnmatch"],
2425
)

python-stdlib/unittest/unittest.py

Lines changed: 69 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import sys
2+
import uos
23

34
try:
45
import io
@@ -280,6 +281,15 @@ def __repr__(self):
280281
self.failuresNum,
281282
)
282283

284+
def __add__(self, other):
285+
self.errorsNum += other.errorsNum
286+
self.failuresNum += other.failuresNum
287+
self.skippedNum += other.skippedNum
288+
self.testsRun += other.testsRun
289+
self.errors.extend(other.errors)
290+
self.failures.extend(other.failures)
291+
return self
292+
283293

284294
def capture_exc(e):
285295
buf = io.StringIO()
@@ -290,7 +300,6 @@ def capture_exc(e):
290300
return buf.getvalue()
291301

292302

293-
# TODO: Uncompliant
294303
def run_suite(c, test_result, suite_name=""):
295304
if isinstance(c, TestSuite):
296305
c.run(test_result)
@@ -343,7 +352,7 @@ def run_one(m):
343352
return
344353

345354
for name in dir(o):
346-
if name.startswith("test_"):
355+
if name.startswith("test"):
347356
m = getattr(o, name)
348357
if not callable(m):
349358
continue
@@ -356,20 +365,65 @@ def run_one(m):
356365
return exceptions
357366

358367

368+
def _test_cases(mod):
369+
for tn in dir(mod):
370+
c = getattr(mod, tn)
371+
if isinstance(c, object) and isinstance(c, type) and issubclass(c, TestCase):
372+
yield c
373+
elif tn.startswith("test_") and callable(c):
374+
yield c
375+
376+
377+
def run_module(runner, module, path, top):
378+
sys_path_initial = sys.path[:]
379+
# Add script dir and top dir to import path
380+
sys.path.insert(0, str(path))
381+
if top:
382+
sys.path.insert(1, top)
383+
try:
384+
suite = TestSuite(module)
385+
m = __import__(module) if isinstance(module, str) else module
386+
for c in _test_cases(m):
387+
suite.addTest(c)
388+
result = runner.run(suite)
389+
return result
390+
391+
finally:
392+
sys.path[:] = sys_path_initial
393+
394+
395+
def discover(runner: TestRunner):
396+
from unittest_discover import discover
397+
398+
return discover(runner=runner)
399+
400+
359401
def main(module="__main__"):
360-
def test_cases(m):
361-
for tn in dir(m):
362-
c = getattr(m, tn)
363-
if isinstance(c, object) and isinstance(c, type) and issubclass(c, TestCase):
364-
yield c
365-
elif tn.startswith("test_") and callable(c):
366-
yield c
367-
368-
m = __import__(module) if isinstance(module, str) else module
369-
suite = TestSuite(m.__name__)
370-
for c in test_cases(m):
371-
suite.addTest(c)
372402
runner = TestRunner()
373-
result = runner.run(suite)
403+
404+
if len(sys.argv) <= 1:
405+
result = discover(runner)
406+
elif sys.argv[0].split(".")[0] == "unittest" and sys.argv[1] == "discover":
407+
result = discover(runner)
408+
else:
409+
for test_spec in sys.argv[1:]:
410+
try:
411+
uos.stat(test_spec)
412+
# test_spec is a local file, run it directly
413+
if "/" in test_spec:
414+
path, fname = test_spec.rsplit("/", 1)
415+
else:
416+
path, fname = ".", test_spec
417+
modname = fname.rsplit(".", 1)[0]
418+
result = run_module(runner, modname, path, None)
419+
420+
except OSError:
421+
# Not a file, treat as import name
422+
result = run_module(runner, test_spec, ".", None)
423+
374424
# Terminate with non zero return code in case of failures
375425
sys.exit(result.failuresNum or result.errorsNum)
426+
427+
428+
if __name__ == "__main__":
429+
main()
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
import argparse
2+
import sys
3+
import uos
4+
from fnmatch import fnmatch
5+
6+
from unittest import TestRunner, TestResult, run_module
7+
8+
9+
def discover(runner: TestRunner):
10+
"""
11+
Implements discover function inspired by https://docs.python.org/3/library/unittest.html#test-discovery
12+
"""
13+
parser = argparse.ArgumentParser()
14+
# parser.add_argument(
15+
# "-v",
16+
# "--verbose",
17+
# action="store_true",
18+
# help="Verbose output",
19+
# )
20+
parser.add_argument(
21+
"-s",
22+
"--start-directory",
23+
dest="start",
24+
default=".",
25+
help="Directory to start discovery",
26+
)
27+
parser.add_argument(
28+
"-p",
29+
"--pattern ",
30+
dest="pattern",
31+
default="test*.py",
32+
help="Pattern to match test files",
33+
)
34+
parser.add_argument(
35+
"-t",
36+
"--top-level-directory",
37+
dest="top",
38+
help="Top level directory of project (defaults to start directory)",
39+
)
40+
args = parser.parse_args(args=sys.argv[2:])
41+
42+
path = args.start
43+
top = args.top or path
44+
45+
return run_all_in_dir(
46+
runner=runner,
47+
path=path,
48+
pattern=args.pattern,
49+
top=top,
50+
)
51+
52+
53+
def run_all_in_dir(runner: TestRunner, path: str, pattern: str, top: str):
54+
DIR_TYPE = 0x4000
55+
56+
result = TestResult()
57+
for fname, type, *_ in uos.ilistdir(path):
58+
if fname in ("..", "."):
59+
continue
60+
if type == DIR_TYPE:
61+
result += run_all_in_dir(
62+
runner=runner,
63+
path="/".join((path, fname)),
64+
pattern=pattern,
65+
top=top,
66+
)
67+
if fnmatch(fname, pattern):
68+
modname = fname[: fname.rfind(".")]
69+
result += run_module(runner, modname, path, top)
70+
return result

0 commit comments

Comments
 (0)