Skip to content

Commit 2e9f29e

Browse files
GH-74116: Allow multiple drain waiters for asyncio.StreamWriter (GH-94705) (#96395)
(cherry picked from commit e5b2453) Co-authored-by: Kumar Aditya <[email protected]> Co-authored-by: Kumar Aditya <[email protected]>
1 parent 126ec34 commit 2e9f29e

File tree

3 files changed

+36
-19
lines changed

3 files changed

+36
-19
lines changed

Lib/asyncio/streams.py

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
'StreamReader', 'StreamWriter', 'StreamReaderProtocol',
33
'open_connection', 'start_server')
44

5+
import collections
56
import socket
67
import sys
78
import warnings
@@ -129,7 +130,7 @@ def __init__(self, loop=None):
129130
else:
130131
self._loop = loop
131132
self._paused = False
132-
self._drain_waiter = None
133+
self._drain_waiters = collections.deque()
133134
self._connection_lost = False
134135

135136
def pause_writing(self):
@@ -144,38 +145,34 @@ def resume_writing(self):
144145
if self._loop.get_debug():
145146
logger.debug("%r resumes writing", self)
146147

147-
waiter = self._drain_waiter
148-
if waiter is not None:
149-
self._drain_waiter = None
148+
for waiter in self._drain_waiters:
150149
if not waiter.done():
151150
waiter.set_result(None)
152151

153152
def connection_lost(self, exc):
154153
self._connection_lost = True
155-
# Wake up the writer if currently paused.
154+
# Wake up the writer(s) if currently paused.
156155
if not self._paused:
157156
return
158-
waiter = self._drain_waiter
159-
if waiter is None:
160-
return
161-
self._drain_waiter = None
162-
if waiter.done():
163-
return
164-
if exc is None:
165-
waiter.set_result(None)
166-
else:
167-
waiter.set_exception(exc)
157+
158+
for waiter in self._drain_waiters:
159+
if not waiter.done():
160+
if exc is None:
161+
waiter.set_result(None)
162+
else:
163+
waiter.set_exception(exc)
168164

169165
async def _drain_helper(self):
170166
if self._connection_lost:
171167
raise ConnectionResetError('Connection lost')
172168
if not self._paused:
173169
return
174-
waiter = self._drain_waiter
175-
assert waiter is None or waiter.cancelled()
176170
waiter = self._loop.create_future()
177-
self._drain_waiter = waiter
178-
await waiter
171+
self._drain_waiters.append(waiter)
172+
try:
173+
await waiter
174+
finally:
175+
self._drain_waiters.remove(waiter)
179176

180177
def _get_close_waiter(self, stream):
181178
raise NotImplementedError

Lib/test/test_asyncio/test_streams.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -864,6 +864,25 @@ def test_streamreaderprotocol_constructor_use_global_loop(self):
864864
self.assertEqual(cm.filename, __file__)
865865
self.assertIs(protocol._loop, self.loop)
866866

867+
def test_multiple_drain(self):
868+
# See https://github.com/python/cpython/issues/74116
869+
drained = 0
870+
871+
async def drainer(stream):
872+
nonlocal drained
873+
await stream._drain_helper()
874+
drained += 1
875+
876+
async def main():
877+
loop = asyncio.get_running_loop()
878+
stream = asyncio.streams.FlowControlMixin(loop)
879+
stream.pause_writing()
880+
loop.call_later(0.1, stream.resume_writing)
881+
await asyncio.gather(*[drainer(stream) for _ in range(10)])
882+
self.assertEqual(drained, 10)
883+
884+
self.loop.run_until_complete(main())
885+
867886
def test_drain_raises(self):
868887
# See http://bugs.python.org/issue25441
869888

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Allow :meth:`asyncio.StreamWriter.drain` to be awaited concurrently by multiple tasks. Patch by Kumar Aditya.

0 commit comments

Comments
 (0)