Skip to content

Commit c556584

Browse files
allow for user passed requests.Session (openai#390)
1 parent 96e7642 commit c556584

File tree

3 files changed

+40
-1
lines changed

3 files changed

+40
-1
lines changed

openai/__init__.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import os
66
import sys
7-
from typing import TYPE_CHECKING, Optional
7+
from typing import TYPE_CHECKING, Optional, Union, Callable
88

99
from contextvars import ContextVar
1010

@@ -36,6 +36,7 @@
3636
from openai.version import VERSION
3737

3838
if TYPE_CHECKING:
39+
import requests
3940
from aiohttp import ClientSession
4041

4142
api_key = os.environ.get("OPENAI_API_KEY")
@@ -58,6 +59,10 @@
5859
debug = False
5960
log = None # Set to either 'debug' or 'info', controls console logging
6061

62+
requestssession: Optional[
63+
Union["requests.Session", Callable[[], "requests.Session"]]
64+
] = None # Provide a requests.Session or Session factory.
65+
6166
aiosession: ContextVar[Optional["ClientSession"]] = ContextVar(
6267
"aiohttp-session", default=None
6368
) # Acts as a global aiohttp ClientSession that reuses connections.

openai/api_requestor.py

+4
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,10 @@ def _aiohttp_proxies_arg(proxy) -> Optional[str]:
7676

7777

7878
def _make_session() -> requests.Session:
79+
if openai.requestssession:
80+
if isinstance(openai.requestssession, requests.Session):
81+
return openai.requestssession
82+
return openai.requestssession()
7983
if not openai.verify_ssl_certs:
8084
warnings.warn("verify_ssl_certs is ignored; openai always verifies.")
8185
s = requests.Session()

openai/tests/test_endpoints.py

+30
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import json
33

44
import pytest
5+
import requests
56

67
import openai
78
from openai import error
@@ -86,3 +87,32 @@ def test_timeout_does_not_error():
8687
model="ada",
8788
request_timeout=10,
8889
)
90+
91+
92+
def test_user_session():
93+
with requests.Session() as session:
94+
openai.requestssession = session
95+
96+
completion = openai.Completion.create(
97+
prompt="hello world",
98+
model="ada",
99+
)
100+
assert completion
101+
102+
103+
def test_user_session_factory():
104+
def factory():
105+
session = requests.Session()
106+
session.mount(
107+
"https://",
108+
requests.adapters.HTTPAdapter(max_retries=4),
109+
)
110+
return session
111+
112+
openai.requestssession = factory
113+
114+
completion = openai.Completion.create(
115+
prompt="hello world",
116+
model="ada",
117+
)
118+
assert completion

0 commit comments

Comments
 (0)