Skip to content

Commit f6ede6e

Browse files
allow multiple waiters
1 parent de388c0 commit f6ede6e

File tree

3 files changed

+37
-18
lines changed

3 files changed

+37
-18
lines changed

Lib/asyncio/streams.py

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
'StreamReader', 'StreamWriter', 'StreamReaderProtocol',
33
'open_connection', 'start_server')
44

5+
import collections
56
import socket
67
import sys
78
import weakref
@@ -128,7 +129,7 @@ def __init__(self, loop=None):
128129
else:
129130
self._loop = loop
130131
self._paused = False
131-
self._drain_waiter = None
132+
self._drain_waiters = collections.deque()
132133
self._connection_lost = False
133134

134135
def pause_writing(self):
@@ -143,9 +144,8 @@ def resume_writing(self):
143144
if self._loop.get_debug():
144145
logger.debug("%r resumes writing", self)
145146

146-
waiter = self._drain_waiter
147-
if waiter is not None:
148-
self._drain_waiter = None
147+
while self._drain_waiters:
148+
waiter = self._drain_waiters.popleft()
149149
if not waiter.done():
150150
waiter.set_result(None)
151151

@@ -154,27 +154,26 @@ def connection_lost(self, exc):
154154
# Wake up the writer if currently paused.
155155
if not self._paused:
156156
return
157-
waiter = self._drain_waiter
158-
if waiter is None:
159-
return
160-
self._drain_waiter = None
161-
if waiter.done():
162-
return
163-
if exc is None:
164-
waiter.set_result(None)
165-
else:
166-
waiter.set_exception(exc)
157+
158+
while self._drain_waiters:
159+
waiter = self._drain_waiters.popleft()
160+
if not waiter.done():
161+
if exc is None:
162+
waiter.set_result(None)
163+
else:
164+
waiter.set_exception(exc)
167165

168166
async def _drain_helper(self):
169167
if self._connection_lost:
170168
raise ConnectionResetError('Connection lost')
171169
if not self._paused:
172170
return
173-
waiter = self._drain_waiter
174-
assert waiter is None or waiter.cancelled()
175171
waiter = self._loop.create_future()
176-
self._drain_waiter = waiter
177-
await waiter
172+
self._drain_waiters.append(waiter)
173+
try:
174+
await waiter
175+
finally:
176+
self._drain_waiters.remove(waiter)
178177

179178
def _get_close_waiter(self, stream):
180179
raise NotImplementedError

Lib/test/test_asyncio/test_streams.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -864,6 +864,25 @@ def test_streamreaderprotocol_constructor_use_global_loop(self):
864864
self.assertEqual(cm.filename, __file__)
865865
self.assertIs(protocol._loop, self.loop)
866866

867+
def test_multiple_drain(self):
868+
# See https://github.com/python/cpython/issues/74116
869+
drained = 0
870+
871+
async def drainer(stream):
872+
nonlocal drained
873+
await stream._drain_helper()
874+
drained += 1
875+
876+
async def main():
877+
loop = asyncio.get_running_loop()
878+
stream = asyncio.streams.FlowControlMixin(loop)
879+
stream.pause_writing()
880+
loop.call_later(0.1, stream.resume_writing)
881+
await asyncio.gather(*[drainer(stream) for _ in range(10)])
882+
self.assertEqual(drained, 10)
883+
884+
self.loop.run_until_complete(main())
885+
867886
def test_drain_raises(self):
868887
# See http://bugs.python.org/issue25441
869888

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fix :meth:`asyncio.StreamWriter.drain` to be awaited concurrently by multiple tasks. Patch by Kumar Aditya.

0 commit comments

Comments
 (0)