18
18
19
19
from ..base import BaseTestCase
20
20
from ..models import PostgresJSON
21
+ from ..sync import database_sync_to_async
22
+
23
+
24
+ def sql_call (use_iterator = False ):
25
+ qs = User .objects .all ()
26
+ if use_iterator :
27
+ qs = qs .iterator ()
28
+ return list (qs )
21
29
22
30
23
31
class SQLPanelTestCase (BaseTestCase ):
@@ -32,7 +40,7 @@ def test_disabled(self):
32
40
def test_recording (self ):
33
41
self .assertEqual (len (self .panel ._queries ), 0 )
34
42
35
- list ( User . objects . all () )
43
+ sql_call ( )
36
44
37
45
# ensure query was logged
38
46
self .assertEqual (len (self .panel ._queries ), 1 )
@@ -51,7 +59,7 @@ def test_recording(self):
51
59
def test_recording_chunked_cursor (self ):
52
60
self .assertEqual (len (self .panel ._queries ), 0 )
53
61
54
- list ( User . objects . all (). iterator () )
62
+ sql_call ( use_iterator = True )
55
63
56
64
# ensure query was logged
57
65
self .assertEqual (len (self .panel ._queries ), 1 )
@@ -61,7 +69,7 @@ def test_recording_chunked_cursor(self):
61
69
wraps = sql_tracking .NormalCursorWrapper ,
62
70
)
63
71
def test_cursor_wrapper_singleton (self , mock_wrapper ):
64
- list ( User . objects . all () )
72
+ sql_call ( )
65
73
66
74
# ensure that cursor wrapping is applied only once
67
75
self .assertEqual (mock_wrapper .call_count , 1 )
@@ -71,7 +79,7 @@ def test_cursor_wrapper_singleton(self, mock_wrapper):
71
79
wraps = sql_tracking .NormalCursorWrapper ,
72
80
)
73
81
def test_chunked_cursor_wrapper_singleton (self , mock_wrapper ):
74
- list ( User . objects . all (). iterator () )
82
+ sql_call ( use_iterator = True )
75
83
76
84
# ensure that cursor wrapping is applied only once
77
85
self .assertEqual (mock_wrapper .call_count , 1 )
@@ -81,7 +89,7 @@ def test_chunked_cursor_wrapper_singleton(self, mock_wrapper):
81
89
wraps = sql_tracking .NormalCursorWrapper ,
82
90
)
83
91
async def test_cursor_wrapper_async (self , mock_wrapper ):
84
- await sync_to_async (list )( User . objects . all () )
92
+ await sync_to_async (sql_call )( )
85
93
86
94
self .assertEqual (mock_wrapper .call_count , 1 )
87
95
@@ -91,11 +99,13 @@ async def test_cursor_wrapper_async(self, mock_wrapper):
91
99
)
92
100
async def test_cursor_wrapper_asyncio_ctx (self , mock_wrapper ):
93
101
self .assertTrue (sql_tracking .recording .get ())
94
- await sync_to_async (list )( User . objects . all () )
102
+ await sync_to_async (sql_call )( )
95
103
96
104
async def task ():
97
105
sql_tracking .recording .set (False )
98
- await sync_to_async (list )(User .objects .all ())
106
+ # Calling this in another context requires the db connections
107
+ # to be closed properly.
108
+ await database_sync_to_async (sql_call )()
99
109
100
110
# Ensure this is called in another context
101
111
await asyncio .create_task (task ())
@@ -106,7 +116,7 @@ async def task():
106
116
def test_generate_server_timing (self ):
107
117
self .assertEqual (len (self .panel ._queries ), 0 )
108
118
109
- list ( User . objects . all () )
119
+ sql_call ( )
110
120
111
121
response = self .panel .process_request (self .request )
112
122
self .panel .generate_stats (self .request , response )
@@ -372,7 +382,7 @@ def test_disable_stacktraces(self):
372
382
self .assertEqual (len (self .panel ._queries ), 0 )
373
383
374
384
with self .settings (DEBUG_TOOLBAR_CONFIG = {"ENABLE_STACKTRACES" : False }):
375
- list ( User . objects . all () )
385
+ sql_call ( )
376
386
377
387
# ensure query was logged
378
388
self .assertEqual (len (self .panel ._queries ), 1 )
0 commit comments