diff --git a/openai/__init__.py b/openai/__init__.py index 2b184226df..76b9f322e1 100644 --- a/openai/__init__.py +++ b/openai/__init__.py @@ -26,6 +26,7 @@ if TYPE_CHECKING: from aiohttp import ClientSession + import requests api_key = os.environ.get("OPENAI_API_KEY") # Path of a file with an API key, whose contents can change. Supercedes @@ -47,6 +48,9 @@ debug = False log = None # Set to either 'debug' or 'info', controls console logging +requestssession: ContextVar[Optional["requests.Session"]] = ContextVar( + "requests-session", default=None +) aiosession: ContextVar[Optional["ClientSession"]] = ContextVar( "aiohttp-session", default=None ) # Acts as a global aiohttp ClientSession that reuses connections. diff --git a/openai/api_requestor.py b/openai/api_requestor.py index 64e55e82ef..117ea83d0a 100644 --- a/openai/api_requestor.py +++ b/openai/api_requestor.py @@ -35,7 +35,7 @@ MAX_CONNECTION_RETRIES = 2 # Has one attribute per thread, 'session'. -_thread_context = threading.local() +# _thread_context = threading.local() def _build_api_url(http://github.com/url,%20query): @@ -510,10 +510,12 @@ def request_raw( url, supplied_headers, method, params, files, request_id ) - if not hasattr(_thread_context, "session"): - _thread_context.session = _make_session() + session = openai.requestssession.get() + if not session: + session = _make_session() + openai.requestssession.set(session) try: - result = _thread_context.session.request( + result = session.request( method, abs_url, headers=headers, diff --git a/openai/easyaz.py b/openai/easyaz.py new file mode 100644 index 0000000000..65fa5213e3 --- /dev/null +++ b/openai/easyaz.py @@ -0,0 +1,57 @@ +import logging +import typing + +if typing.TYPE_CHECKING: + from azure.core.credentials import TokenCredential + +try: + from requests import PreparedRequest, Response, Session + from requests.adapters import HTTPAdapter +except ModuleNotFoundError: + print("You have to install the `requests` library (pip install requests) in order to use easyaz.requests") + exit(-1) + +import azure.identity +import openai + +log = logging.getLogger(__name__) + +def init(endpoint: str, *, credential=None, api_version='2022-12-01'): + openai.api_type = 'azure_ad' + openai.api_key = 'dummy' + openai.api_base = endpoint + openai.api_version = api_version + + if not credential: + credential = azure.identity.DefaultAzureCredential() + + session = Session() + adapter = AzHttpAdapter(credential=credential, scopes=[ "https://cognitiveservices.azure.com/.default" ]) + session.mount(endpoint, adapter) + openai.requestssession.set(session) + + +class AzHttpAdapter(HTTPAdapter): + + def __init__(self, *, credential: "TokenCredential", scopes: list[str] | str, **kwargs: typing.Any): + super().__init__(**kwargs) + self.credential = credential + self.scopes = [ scopes ] if isinstance(scopes, str) else scopes + self.max_recurse = 1 + + def send(self, request: PreparedRequest, stream: bool = ..., timeout: None | float | tuple[float, float] | tuple[float, None] = ..., verify: bool | str = ..., cert: None | bytes | str | tuple[bytes | str, bytes | str] = ..., proxies: typing.Mapping[str, str] | None = ..., *, recurse:int=0) -> Response: + initial_response = super().send(request, stream, timeout, verify, cert, proxies) + if initial_response.status_code != 401 or recurse > self.max_recurse: + # Only do the auth dance if we are challenged... + return initial_response + + log.info('Received 401 response from service - grabbing a token!') + + # Drain response + initial_response.content + + # Fill in a token + new_request = request.copy() + access_token = self.credential.get_token(*self.scopes) # TODO: sniff out claims from response + new_request.headers['Authorization'] = 'Bearer ' + access_token.token + return self.send(new_request, stream, timeout, verify, cert, proxies, recurse=recurse + 1) \ No newline at end of file