Skip to content

Commit a2e9ab3

Browse files
committed
extmod/asyncio: Support gather of tasks that finish early.
Adds support to asyncio.gather() for the case that one or more (or all) sub-tasks finish and/or raise an exception before the gather starts. Signed-off-by: Damien George <[email protected]>
1 parent 1e8cc6c commit a2e9ab3

File tree

3 files changed

+102
-17
lines changed

3 files changed

+102
-17
lines changed

extmod/asyncio/core.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,11 @@ def run_until_complete(main_task=None):
219219
elif t.state is None:
220220
# Task is already finished and nothing await'ed on the task,
221221
# so call the exception handler.
222+
223+
# Save exception raised by the coro for later use.
224+
t.data = exc
225+
226+
# Create exception context and call the exception handler.
222227
_exc_context["exception"] = exc
223228
_exc_context["future"] = t
224229
Loop.call_exception_handler(_exc_context)

extmod/asyncio/funcs.py

Lines changed: 32 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,6 @@ def remove(t):
6363

6464
# async
6565
def gather(*aws, return_exceptions=False):
66-
if not aws:
67-
return []
68-
6966
def done(t, er):
7067
# Sub-task "t" has finished, with exception "er".
7168
nonlocal state
@@ -86,26 +83,39 @@ def done(t, er):
8683
# Gather waiting is done, schedule the main gather task.
8784
core._task_queue.push(gather_task)
8885

86+
# Prepare the sub-tasks for the gather.
87+
# The `state` variable counts the number of tasks to wait for, and can be negative
88+
# if the gather should not run at all (because a task already had an exception).
8989
ts = [core._promote_to_task(aw) for aw in aws]
90+
state = 0
9091
for i in range(len(ts)):
91-
if ts[i].state is not True:
92-
# Task is not running, gather not currently supported for this case.
92+
if ts[i].state is True:
93+
# Task is running, register the callback to call when the task is done.
94+
ts[i].state = done
95+
state += 1
96+
elif not ts[i].state:
97+
# Task finished already.
98+
if not isinstance(ts[i].data, StopIteration):
99+
# Task finished by raising an exception.
100+
if not return_exceptions:
101+
# Do not run this gather at all.
102+
state = -len(ts)
103+
else:
104+
# Task being waited on, gather not currently supported for this case.
93105
raise RuntimeError("can't gather")
94-
# Register the callback to call when the task is done.
95-
ts[i].state = done
96106

97107
# Set the state for execution of the gather.
98108
gather_task = core.cur_task
99-
state = len(ts)
100109
cancel_all = False
101110

102-
# Wait for the a sub-task to need attention.
103-
gather_task.data = _Remove
104-
try:
105-
yield
106-
except core.CancelledError as er:
107-
cancel_all = True
108-
state = er
111+
# Wait for a sub-task to need attention (if there are any to wait for).
112+
if state > 0:
113+
gather_task.data = _Remove
114+
try:
115+
yield
116+
except core.CancelledError as er:
117+
cancel_all = True
118+
state = er
109119

110120
# Clean up tasks.
111121
for i in range(len(ts)):
@@ -118,8 +128,13 @@ def done(t, er):
118128
# Sub-task ran to completion, get its return value.
119129
ts[i] = ts[i].data.value
120130
else:
121-
# Sub-task had an exception with return_exceptions==True, so get its exception.
122-
ts[i] = ts[i].data
131+
# Sub-task had an exception.
132+
if return_exceptions:
133+
# Get the sub-task exception to return in the list of return values.
134+
ts[i] = ts[i].data
135+
elif isinstance(state, int):
136+
# Raise the sub-task exception, if there is not already an exception to raise.
137+
state = ts[i].data
123138

124139
# Either this gather was cancelled, or one of the sub-tasks raised an exception with
125140
# return_exceptions==False, so reraise the exception here.
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
# Test asyncio.gather() when a task is already finished before the gather starts.
2+
3+
try:
4+
import asyncio
5+
except ImportError:
6+
print("SKIP")
7+
raise SystemExit
8+
9+
10+
# CPython and MicroPython differ in when they signal (and print) that a task raised an
11+
# uncaught exception. So define an empty custom_handler() to suppress this output.
12+
def custom_handler(loop, context):
13+
pass
14+
15+
16+
async def task_that_finishes_early(id, event, fail):
17+
print("task_that_finishes_early", id)
18+
event.set()
19+
if fail:
20+
raise ValueError("intentional exception", id)
21+
22+
23+
async def task_that_runs():
24+
for i in range(5):
25+
print("task_that_runs", i)
26+
await asyncio.sleep(0)
27+
28+
29+
async def main(start_task_that_runs, task_fail, return_exceptions):
30+
print("== start", start_task_that_runs, task_fail, return_exceptions)
31+
32+
# Set exception handler to suppress exception output.
33+
loop = asyncio.get_event_loop()
34+
loop.set_exception_handler(custom_handler)
35+
36+
# Create tasks.
37+
event_a = asyncio.Event()
38+
event_b = asyncio.Event()
39+
tasks = []
40+
if start_task_that_runs:
41+
tasks.append(asyncio.create_task(task_that_runs()))
42+
tasks.append(asyncio.create_task(task_that_finishes_early("a", event_a, task_fail)))
43+
tasks.append(asyncio.create_task(task_that_finishes_early("b", event_b, task_fail)))
44+
45+
# Make sure task_that_finishes_early() are both done, before calling gather().
46+
await event_a.wait()
47+
await event_b.wait()
48+
49+
# Gather the tasks.
50+
try:
51+
result = "complete", await asyncio.gather(*tasks, return_exceptions=return_exceptions)
52+
except Exception as er:
53+
result = "exception", er, start_task_that_runs and tasks[0].done()
54+
55+
# Wait for the final task to finish (if it was started).
56+
if start_task_that_runs:
57+
await tasks[0]
58+
59+
# Print results.
60+
print(result)
61+
62+
63+
# Run the test in the 8 different combinations of its arguments.
64+
for i in range(8):
65+
asyncio.run(main(bool(i & 4), bool(i & 2), bool(i & 1)))

0 commit comments

Comments
 (0)