Skip to content

Commit 7af43ce

Browse files
authored
Fix API requestor hanging when not using a global session (openai#167)
1 parent fb4b672 commit 7af43ce

File tree

2 files changed

+47
-28
lines changed

2 files changed

+47
-28
lines changed

openai/api_requestor.py

Lines changed: 36 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,18 @@
44
import sys
55
import threading
66
import warnings
7+
from contextlib import asynccontextmanager
78
from json import JSONDecodeError
8-
from typing import AsyncGenerator, Dict, Iterator, Optional, Tuple, Union, overload
9+
from typing import (
10+
AsyncGenerator,
11+
AsyncIterator,
12+
Dict,
13+
Iterator,
14+
Optional,
15+
Tuple,
16+
Union,
17+
overload,
18+
)
919
from urllib.parse import urlencode, urlsplit, urlunsplit
1020

1121
import aiohttp
@@ -284,17 +294,19 @@ async def arequest(
284294
request_id: Optional[str] = None,
285295
request_timeout: Optional[Union[float, Tuple[float, float]]] = None,
286296
) -> Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool, str]:
287-
result = await self.arequest_raw(
288-
method.lower(),
289-
url,
290-
params=params,
291-
supplied_headers=headers,
292-
files=files,
293-
request_id=request_id,
294-
request_timeout=request_timeout,
295-
)
296-
resp, got_stream = await self._interpret_async_response(result, stream)
297-
return resp, got_stream, self.api_key
297+
async with aiohttp_session() as session:
298+
result = await self.arequest_raw(
299+
method.lower(),
300+
url,
301+
session,
302+
params=params,
303+
supplied_headers=headers,
304+
files=files,
305+
request_id=request_id,
306+
request_timeout=request_timeout,
307+
)
308+
resp, got_stream = await self._interpret_async_response(result, stream)
309+
return resp, got_stream, self.api_key
298310

299311
def handle_error_response(self, rbody, rcode, resp, rheaders, stream_error=False):
300312
try:
@@ -514,6 +526,7 @@ async def arequest_raw(
514526
self,
515527
method,
516528
url,
529+
session,
517530
*,
518531
params=None,
519532
supplied_headers: Optional[Dict[str, str]] = None,
@@ -534,7 +547,6 @@ async def arequest_raw(
534547
timeout = aiohttp.ClientTimeout(
535548
total=request_timeout if request_timeout else TIMEOUT_SECS
536549
)
537-
user_set_session = openai.aiosession.get()
538550

539551
if files:
540552
# TODO: Use `aiohttp.MultipartWriter` to create the multipart form data here.
@@ -552,11 +564,7 @@ async def arequest_raw(
552564
"timeout": timeout,
553565
}
554566
try:
555-
if user_set_session:
556-
result = await user_set_session.request(**request_kwargs)
557-
else:
558-
async with aiohttp.ClientSession() as session:
559-
result = await session.request(**request_kwargs)
567+
result = await session.request(**request_kwargs)
560568
util.log_info(
561569
"OpenAI API response",
562570
path=abs_url,
@@ -648,3 +656,13 @@ def _interpret_response_line(
648656
rbody, rcode, resp.data, rheaders, stream_error=stream_error
649657
)
650658
return resp
659+
660+
661+
@asynccontextmanager
662+
async def aiohttp_session() -> AsyncIterator[aiohttp.ClientSession]:
663+
user_set_session = openai.aiosession.get()
664+
if user_set_session:
665+
yield user_set_session
666+
else:
667+
async with aiohttp.ClientSession() as session:
668+
yield session

openai/api_resources/file.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -192,16 +192,17 @@ async def adownload(
192192
id, api_key, api_base, api_type, api_version, organization
193193
)
194194

195-
result = await requestor.arequest_raw("get", url)
196-
if not 200 <= result.status < 300:
197-
raise requestor.handle_error_response(
198-
result.content,
199-
result.status,
200-
json.loads(cast(bytes, result.content)),
201-
result.headers,
202-
stream_error=False,
203-
)
204-
return result.content
195+
async with api_requestor.aiohttp_session() as session:
196+
result = await requestor.arequest_raw("get", url, session)
197+
if not 200 <= result.status < 300:
198+
raise requestor.handle_error_response(
199+
result.content,
200+
result.status,
201+
json.loads(cast(bytes, result.content)),
202+
result.headers,
203+
stream_error=False,
204+
)
205+
return result.content
205206

206207
@classmethod
207208
def __find_matching_files(cls, name, all_files, purpose):

0 commit comments

Comments
 (0)