Skip to content

Proof of concept on how to simplify using openai on azure #1

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions openai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

if TYPE_CHECKING:
from aiohttp import ClientSession
import requests

api_key = os.environ.get("OPENAI_API_KEY")
# Path of a file with an API key, whose contents can change. Supersedes
Expand All @@ -47,6 +48,9 @@
debug = False
log = None # Set to either 'debug' or 'info', controls console logging

requestssession: ContextVar[Optional["requests.Session"]] = ContextVar(
"requests-session", default=None
)
aiosession: ContextVar[Optional["ClientSession"]] = ContextVar(
"aiohttp-session", default=None
) # Acts as a global aiohttp ClientSession that reuses connections.
Expand Down
10 changes: 6 additions & 4 deletions openai/api_requestor.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
MAX_CONNECTION_RETRIES = 2

# Has one attribute per thread, 'session'.
_thread_context = threading.local()
# _thread_context = threading.local()


def _build_api_url(url, query):
Expand Down Expand Up @@ -510,10 +510,12 @@ def request_raw(
url, supplied_headers, method, params, files, request_id
)

if not hasattr(_thread_context, "session"):
_thread_context.session = _make_session()
session = openai.requestssession.get()
if not session:
session = _make_session()
openai.requestssession.set(session)
try:
result = _thread_context.session.request(
result = session.request(
method,
abs_url,
headers=headers,
Expand Down
57 changes: 57 additions & 0 deletions openai/easyaz.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import logging
import typing

if typing.TYPE_CHECKING:
from azure.core.credentials import TokenCredential

try:
from requests import PreparedRequest, Response, Session
from requests.adapters import HTTPAdapter
except ModuleNotFoundError:
print("You have to install the `requests` library (pip install requests) in order to use easyaz.requests")
exit(-1)

import azure.identity
import openai

log = logging.getLogger(__name__)

def init(endpoint: str, *, credential=None, api_version='2022-12-01'):
    """Configure the ``openai`` module to talk to an Azure endpoint with Azure AD auth.

    Sets ``openai.api_type``, ``openai.api_key``, ``openai.api_base`` and
    ``openai.api_version``, then builds a ``requests`` session whose adapter
    for ``endpoint`` is an ``AzHttpAdapter`` and stores that session in
    ``openai.requestssession``.

    Args:
        endpoint: Base URL of the service; used both as ``openai.api_base``
            and as the URL prefix the adapter is mounted on.
        credential: Object used by the adapter to obtain tokens. When falsy,
            ``azure.identity.DefaultAzureCredential()`` is created.
        api_version: Value assigned to ``openai.api_version``.
    """
    # Module-level configuration first, exactly as before any credential work.
    openai.api_type = 'azure_ad'
    openai.api_key = 'dummy'  # placeholder value; the adapter supplies the bearer token
    openai.api_base = endpoint
    openai.api_version = api_version

    # Fall back to the default Azure credential when the caller gave none.
    token_source = credential or azure.identity.DefaultAzureCredential()

    az_session = Session()
    az_session.mount(
        endpoint,
        AzHttpAdapter(
            credential=token_source,
            scopes=["https://cognitiveservices.azure.com/.default"],
        ),
    )
    openai.requestssession.set(az_session)


class AzHttpAdapter(HTTPAdapter):
    """HTTP adapter that answers a 401 response by re-sending with a bearer token.

    A request is first sent unchanged. Only if the service replies with
    HTTP 401 is a token requested from ``credential`` for ``scopes`` and a
    copy of the request re-sent with ``Authorization: Bearer <token>``.
    """

    # NOTE(review): open question raised in code review -- if more
    # functionality is needed (e.g. an auxiliary header), should this adapter
    # be extended, or should such features be plug-ins that users can choose?

    def __init__(
        self,
        *,
        credential: "TokenCredential",
        scopes: typing.Union[typing.List[str], str],
        **kwargs: typing.Any,
    ):
        """Create the adapter.

        Args:
            credential: Object whose ``get_token(*scopes)`` returns an object
                with a ``token`` attribute.
            scopes: A single scope string or a list of scope strings.
            **kwargs: Forwarded unchanged to ``HTTPAdapter.__init__``.
        """
        super().__init__(**kwargs)
        self.credential = credential
        # Normalise to a list; copy so later changes to the caller's list
        # do not affect this adapter.
        self.scopes = [scopes] if isinstance(scopes, str) else list(scopes)
        # Maximum number of token-bearing re-sends after a 401.
        self.max_recurse = 1

    def send(
        self,
        request: PreparedRequest,
        stream: bool = False,
        timeout: typing.Union[
            None, float, typing.Tuple[float, float], typing.Tuple[float, None]
        ] = None,
        verify: typing.Union[bool, str] = True,
        cert: typing.Union[
            None,
            bytes,
            str,
            typing.Tuple[typing.Union[bytes, str], typing.Union[bytes, str]],
        ] = None,
        proxies: typing.Optional[typing.Mapping[str, str]] = None,
        *,
        recurse: int = 0,
    ) -> Response:
        """Send ``request``; on a 401, fetch a token and re-send.

        The positional parameters mirror ``HTTPAdapter.send`` and are passed
        straight through to it (defaults match the base class rather than the
        ``...`` placeholders used in type stubs).

        Args:
            request: The prepared request to send.
            stream, timeout, verify, cert, proxies: Forwarded to
                ``HTTPAdapter.send``.
            recurse: Internal counter of re-sends already performed for this
                request; callers should leave it at 0.

        Returns:
            The first non-401 response, or the last 401 response once
            ``max_recurse`` re-sends have been made.
        """
        initial_response = super().send(request, stream, timeout, verify, cert, proxies)
        if initial_response.status_code != 401 or recurse >= self.max_recurse:
            # Only do the auth dance if we are challenged, and stop once the
            # re-send budget is used up so a persistent 401 is returned as-is.
            return initial_response

        log.info('Received 401 response from service - grabbing a token!')

        # Drain the 401 response: accessing .content reads the whole body
        # before the request is sent again.
        initial_response.content

        # Fill in a token on a copy so the caller's request is left untouched.
        new_request = request.copy()
        # TODO: sniff out claims from response
        # NOTE(review): reviewers asked for the token to be cached here
        # instead of calling get_token on every challenge.
        access_token = self.credential.get_token(*self.scopes)
        new_request.headers['Authorization'] = 'Bearer ' + access_token.token
        return self.send(new_request, stream, timeout, verify, cert, proxies, recurse=recurse + 1)