4
4
import sys
5
5
import threading
6
6
import warnings
7
+ from contextlib import asynccontextmanager
7
8
from json import JSONDecodeError
8
- from typing import AsyncGenerator , Dict , Iterator , Optional , Tuple , Union , overload
9
+ from typing import (
10
+ AsyncGenerator ,
11
+ AsyncIterator ,
12
+ Dict ,
13
+ Iterator ,
14
+ Optional ,
15
+ Tuple ,
16
+ Union ,
17
+ overload ,
18
+ )
9
19
from urllib .parse import urlencode , urlsplit , urlunsplit
10
20
11
21
import aiohttp
@@ -284,17 +294,19 @@ async def arequest(
284
294
request_id : Optional [str ] = None ,
285
295
request_timeout : Optional [Union [float , Tuple [float , float ]]] = None ,
286
296
) -> Tuple [Union [OpenAIResponse , AsyncGenerator [OpenAIResponse , None ]], bool , str ]:
287
- result = await self .arequest_raw (
288
- method .lower (),
289
- url ,
290
- params = params ,
291
- supplied_headers = headers ,
292
- files = files ,
293
- request_id = request_id ,
294
- request_timeout = request_timeout ,
295
- )
296
- resp , got_stream = await self ._interpret_async_response (result , stream )
297
- return resp , got_stream , self .api_key
297
+ async with aiohttp_session () as session :
298
+ result = await self .arequest_raw (
299
+ method .lower (),
300
+ url ,
301
+ session ,
302
+ params = params ,
303
+ supplied_headers = headers ,
304
+ files = files ,
305
+ request_id = request_id ,
306
+ request_timeout = request_timeout ,
307
+ )
308
+ resp , got_stream = await self ._interpret_async_response (result , stream )
309
+ return resp , got_stream , self .api_key
298
310
299
311
def handle_error_response (self , rbody , rcode , resp , rheaders , stream_error = False ):
300
312
try :
@@ -514,6 +526,7 @@ async def arequest_raw(
514
526
self ,
515
527
method ,
516
528
url ,
529
+ session ,
517
530
* ,
518
531
params = None ,
519
532
supplied_headers : Optional [Dict [str , str ]] = None ,
@@ -534,7 +547,6 @@ async def arequest_raw(
534
547
timeout = aiohttp .ClientTimeout (
535
548
total = request_timeout if request_timeout else TIMEOUT_SECS
536
549
)
537
- user_set_session = openai .aiosession .get ()
538
550
539
551
if files :
540
552
# TODO: Use `aiohttp.MultipartWriter` to create the multipart form data here.
@@ -552,11 +564,7 @@ async def arequest_raw(
552
564
"timeout" : timeout ,
553
565
}
554
566
try :
555
- if user_set_session :
556
- result = await user_set_session .request (** request_kwargs )
557
- else :
558
- async with aiohttp .ClientSession () as session :
559
- result = await session .request (** request_kwargs )
567
+ result = await session .request (** request_kwargs )
560
568
util .log_info (
561
569
"OpenAI API response" ,
562
570
path = abs_url ,
@@ -648,3 +656,13 @@ def _interpret_response_line(
648
656
rbody , rcode , resp .data , rheaders , stream_error = stream_error
649
657
)
650
658
return resp
659
+
660
+
661
+ @asynccontextmanager
662
+ async def aiohttp_session () -> AsyncIterator [aiohttp .ClientSession ]:
663
+ user_set_session = openai .aiosession .get ()
664
+ if user_set_session :
665
+ yield user_set_session
666
+ else :
667
+ async with aiohttp .ClientSession () as session :
668
+ yield session
0 commit comments