Skip to content

Commit ee558b5

Browse files
committed
Close DB connections when testing async functionality.
1 parent 75a62b7 commit ee558b5

File tree

2 files changed

+41
-9
lines changed

2 files changed

+41
-9
lines changed

tests/panels/test_sql.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,14 @@
1818

1919
from ..base import BaseTestCase
2020
from ..models import PostgresJSON
21+
from ..sync import database_sync_to_async
22+
23+
24+
def sql_call(use_iterator=False):
25+
qs = User.objects.all()
26+
if use_iterator:
27+
qs = qs.iterator()
28+
return list(qs)
2129

2230

2331
class SQLPanelTestCase(BaseTestCase):
@@ -32,7 +40,7 @@ def test_disabled(self):
3240
def test_recording(self):
3341
self.assertEqual(len(self.panel._queries), 0)
3442

35-
list(User.objects.all())
43+
sql_call()
3644

3745
# ensure query was logged
3846
self.assertEqual(len(self.panel._queries), 1)
@@ -51,7 +59,7 @@ def test_recording(self):
5159
def test_recording_chunked_cursor(self):
5260
self.assertEqual(len(self.panel._queries), 0)
5361

54-
list(User.objects.all().iterator())
62+
sql_call(use_iterator=True)
5563

5664
# ensure query was logged
5765
self.assertEqual(len(self.panel._queries), 1)
@@ -61,7 +69,7 @@ def test_recording_chunked_cursor(self):
6169
wraps=sql_tracking.NormalCursorWrapper,
6270
)
6371
def test_cursor_wrapper_singleton(self, mock_wrapper):
64-
list(User.objects.all())
72+
sql_call()
6573

6674
# ensure that cursor wrapping is applied only once
6775
self.assertEqual(mock_wrapper.call_count, 1)
@@ -71,7 +79,7 @@ def test_cursor_wrapper_singleton(self, mock_wrapper):
7179
wraps=sql_tracking.NormalCursorWrapper,
7280
)
7381
def test_chunked_cursor_wrapper_singleton(self, mock_wrapper):
74-
list(User.objects.all().iterator())
82+
sql_call(use_iterator=True)
7583

7684
# ensure that cursor wrapping is applied only once
7785
self.assertEqual(mock_wrapper.call_count, 1)
@@ -81,7 +89,7 @@ def test_chunked_cursor_wrapper_singleton(self, mock_wrapper):
8189
wraps=sql_tracking.NormalCursorWrapper,
8290
)
8391
async def test_cursor_wrapper_async(self, mock_wrapper):
84-
await sync_to_async(list)(User.objects.all())
92+
await sync_to_async(sql_call)()
8593

8694
self.assertEqual(mock_wrapper.call_count, 1)
8795

@@ -91,11 +99,13 @@ async def test_cursor_wrapper_async(self, mock_wrapper):
9199
)
92100
async def test_cursor_wrapper_asyncio_ctx(self, mock_wrapper):
93101
self.assertTrue(sql_tracking.recording.get())
94-
await sync_to_async(list)(User.objects.all())
102+
await sync_to_async(sql_call)()
95103

96104
async def task():
97105
sql_tracking.recording.set(False)
98-
await sync_to_async(list)(User.objects.all())
106+
# Calling this in another context requires the db connections
107+
# to be closed properly.
108+
await database_sync_to_async(sql_call)()
99109

100110
# Ensure this is called in another context
101111
await asyncio.create_task(task())
@@ -106,7 +116,7 @@ async def task():
106116
def test_generate_server_timing(self):
107117
self.assertEqual(len(self.panel._queries), 0)
108118

109-
list(User.objects.all())
119+
sql_call()
110120

111121
response = self.panel.process_request(self.request)
112122
self.panel.generate_stats(self.request, response)
@@ -372,7 +382,7 @@ def test_disable_stacktraces(self):
372382
self.assertEqual(len(self.panel._queries), 0)
373383

374384
with self.settings(DEBUG_TOOLBAR_CONFIG={"ENABLE_STACKTRACES": False}):
375-
list(User.objects.all())
385+
sql_call()
376386

377387
# ensure query was logged
378388
self.assertEqual(len(self.panel._queries), 1)

tests/sync.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
"""
2+
Taken from channels.db
3+
"""
4+
from asgiref.sync import SyncToAsync
5+
from django.db import close_old_connections
6+
7+
8+
class DatabaseSyncToAsync(SyncToAsync):
9+
"""
10+
SyncToAsync version that cleans up old database connections when it exits.
11+
"""
12+
13+
def thread_handler(self, loop, *args, **kwargs):
14+
close_old_connections()
15+
try:
16+
return super().thread_handler(loop, *args, **kwargs)
17+
finally:
18+
close_old_connections()
19+
20+
21+
# The class is TitleCased, but we want to encourage use as a callable/decorator
22+
database_sync_to_async = DatabaseSyncToAsync

0 commit comments

Comments
 (0)