Skip to content

Initial prototype v1 openai support for azure #7

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 10 commits into
base: v1
Choose a base branch
from
2 changes: 1 addition & 1 deletion src/openai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
"AsyncStream",
"OpenAI",
"AsyncOpenAI",
"file_from_path",
"file_from_path"
]

from .version import VERSION as VERSION
Expand Down
14 changes: 10 additions & 4 deletions src/openai/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,11 @@

class OpenAI(SyncAPIClient):
completions: resources.Completions
chat: resources.Chat

@property
def chat(self) -> resources.chat.Chat:
return self._chat
# chat: resources.chat.Chat
edits: resources.Edits
embeddings: resources.Embeddings
files: resources.Files
Expand Down Expand Up @@ -122,7 +126,7 @@ def __init__(
self._default_stream_cls = Stream

self.completions = resources.Completions(self)
self.chat = resources.Chat(self)
self._chat = resources.Chat(self)
self.edits = resources.Edits(self)
self.embeddings = resources.Embeddings(self)
self.files = resources.Files(self)
Expand Down Expand Up @@ -244,7 +248,9 @@ def _make_status_error(

class AsyncOpenAI(AsyncAPIClient):
completions: resources.AsyncCompletions
chat: resources.AsyncChat
@property
def chat(self) -> resources.AsyncChat:
return self._chat
edits: resources.AsyncEdits
embeddings: resources.AsyncEmbeddings
files: resources.AsyncFiles
Expand Down Expand Up @@ -320,7 +326,7 @@ def __init__(
self._default_stream_cls = AsyncStream

self.completions = resources.AsyncCompletions(self)
self.chat = resources.AsyncChat(self)
self._chat = resources.AsyncChat(self)
self.edits = resources.AsyncEdits(self)
self.embeddings = resources.AsyncEmbeddings(self)
self.files = resources.AsyncFiles(self)
Expand Down
9 changes: 9 additions & 0 deletions src/openai/azure/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from ._sync_client import AzureOpenAIClient
from ._async_client import AsyncAzureOpenAIClient
from ._credential import TokenCredential

__all__ = [
"AzureOpenAIClient",
"TokenCredential",
"AsyncAzureOpenAIClient",
]
855 changes: 855 additions & 0 deletions src/openai/azure/_async_client.py

Large diffs are not rendered by default.

82 changes: 82 additions & 0 deletions src/openai/azure/_azuremodels.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
from typing import List, Optional
from typing_extensions import TypedDict, Literal
from openai._models import BaseModel as BaseModel

from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessage
from openai.types.chat.chat_completion import Choice as ChatChoice
from openai.types.chat.chat_completion_chunk import ChoiceDelta, Choice as ChatChoiceDelta
from openai.types.completion import Completion
from openai.types.completion_choice import CompletionChoice


# Role values accepted in Azure chat completion messages.
AzureChatCompletionRole = Literal["system", "user", "assistant", "function", "tool"]


class ChatExtensionConfiguration(TypedDict):
    """Request-side configuration entry for an Azure chat extension."""

    # Only "AzureCognitiveSearch" is declared here.
    type: Literal["AzureCognitiveSearch"]
    # Extension-specific payload; its schema is opaque at this layer.
    parameters: object


class ContentFilterResult(BaseModel):
    """Content-filter outcome for a single category (e.g. hate, violence)."""

    severity: Literal["safe", "low", "medium", "high"]
    # Presumably True when this category's filter triggered — confirm against service docs.
    filtered: bool


class Error(BaseModel):
    """Error payload embedded in content-filter results (see ContentFilterResults.error)."""

    code: str
    message: str


class ContentFilterResults(BaseModel):
    """Per-category content-filter annotations for one prompt or choice."""

    # Defaults added: pydantic v2 treats an Optional field without a default as
    # *required*, so a response omitting any category would fail validation.
    # This also matches the `Optional[X] = None` convention of the SDK's models.
    hate: Optional[ContentFilterResult] = None
    self_harm: Optional[ContentFilterResult] = None
    violence: Optional[ContentFilterResult] = None
    sexual: Optional[ContentFilterResult] = None
    error: Optional[Error] = None


class PromptFilterResult(BaseModel):
    """Content-filter results for one input prompt, identified by its index."""

    prompt_index: int
    # Default added: an un-defaulted Optional field is required under pydantic v2.
    content_filter_results: Optional[ContentFilterResults] = None


class AzureChatExtensionsMessageContext(BaseModel):
    """Extra messages attached by Azure chat extensions to a response message."""

    # Default added: an un-defaulted Optional field is required under pydantic v2.
    messages: Optional[List[ChatCompletionMessage]] = None


class AzureChatCompletionMessage(ChatCompletionMessage):
    """Chat message extended with Azure extension context and role values."""

    # Default added: an un-defaulted Optional field is required under pydantic v2.
    context: Optional[AzureChatExtensionsMessageContext] = None
    # Overrides the base model's role type; the ignore silences the override mismatch.
    role: AzureChatCompletionRole  # type: ignore


class AzureChatCompletionChoice(ChatChoice):
    """Chat completion choice carrying Azure content-filter annotations."""

    # Default added: an un-defaulted Optional field is required under pydantic v2.
    content_filter_results: Optional[ContentFilterResults] = None
    # Narrows the base message type to the Azure variant.
    message: AzureChatCompletionMessage  # type: ignore


class AzureChatCompletion(ChatCompletion):
    """Chat completion response carrying Azure prompt-filter annotations."""

    # Narrows the base choices type to the Azure variant.
    choices: List[AzureChatCompletionChoice]  # type: ignore
    # Default added: an un-defaulted Optional field is required under pydantic v2.
    prompt_filter_results: Optional[List[PromptFilterResult]] = None


class AzureChoiceDelta(ChoiceDelta):
    """Streaming delta extended with Azure extension context."""

    # Default added: an un-defaulted Optional field is required under pydantic v2.
    context: Optional[AzureChatExtensionsMessageContext] = None


class AzureChatCompletionChoiceDelta(ChatChoiceDelta):
    """Streaming choice carrying Azure content-filter annotations."""

    # Narrows the base delta type to the Azure variant.
    delta: AzureChoiceDelta  # type: ignore
    # Default added: an un-defaulted Optional field is required under pydantic v2.
    content_filter_results: Optional[ContentFilterResults] = None


class AzureChatCompletionChunk(ChatCompletionChunk):
    """Streaming chunk carrying Azure prompt-filter annotations."""

    # Narrows the base choices type to the Azure variant.
    choices: List[AzureChatCompletionChoiceDelta]  # type: ignore
    # Default added: an un-defaulted Optional field is required under pydantic v2.
    prompt_filter_results: Optional[List[PromptFilterResult]] = None


class AzureCompletionChoice(CompletionChoice):
    """Completion choice carrying Azure content-filter annotations."""

    # Default added: an un-defaulted Optional field is required under pydantic v2.
    content_filter_results: Optional[ContentFilterResults] = None


class AzureCompletion(Completion):
    """Completion response carrying Azure prompt-filter annotations."""

    # Narrows the base choices type to the Azure variant.
    choices: List[AzureCompletionChoice]  # type: ignore
    # Default added: an un-defaulted Optional field is required under pydantic v2.
    prompt_filter_results: Optional[List[PromptFilterResult]] = None
46 changes: 46 additions & 0 deletions src/openai/azure/_credential.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from typing import AsyncGenerator, Generator, Any
import time
import asyncio
import httpx


class TokenCredential:
    """Placeholder/example token credential.

    A real implementation would be compatible with e.g. azure-identity and
    should also be easily adaptable to other token credential implementations.
    """

    # Scope for the Azure Cognitive Services token audience.
    _SCOPE = "https://cognitiveservices.azure.com/.default"

    def __init__(self) -> None:
        # Imported lazily so azure-identity stays an optional dependency.
        from azure.identity import DefaultAzureCredential

        self._credential = DefaultAzureCredential()

    def get_token(self):
        # Returns the bare token string, not the full AccessToken object.
        return self._credential.get_token(self._SCOPE).token


class TokenAuth(httpx.Auth):
    """httpx auth flow that attaches an Azure AD bearer token to each request.

    The fetched token is cached and refreshed once it is within 5 minutes of
    expiry.

    NOTE(review): the cache/expiry logic expects the credential's ``get_token``
    to return an azure-identity style AccessToken (with ``.token`` and
    ``.expires_on``), but the placeholder ``TokenCredential.get_token`` in this
    module returns a bare string — confirm the intended credential protocol.
    """

    # Refresh the token once it is within this many seconds of expiring.
    _REFRESH_MARGIN_SECONDS = 300

    def __init__(self, credential: "TokenCredential") -> None:
        self._credential = credential
        self._async_lock = asyncio.Lock()
        self.cached_token = None

    def _needs_refresh(self) -> bool:
        # True when nothing is cached yet or the cached token is near expiry.
        return (
            self.cached_token is None
            or self.cached_token.expires_on - time.time() < self._REFRESH_MARGIN_SECONDS
        )

    def sync_get_token(self) -> str:
        if self._needs_refresh():
            # Bug fix: store the fetched token. Previously ``cached_token`` was
            # never assigned, so the cache was dead and every single request
            # re-fetched a token from the credential.
            self.cached_token = self._credential.get_token(
                "https://cognitiveservices.azure.com/.default"
            )
        return self.cached_token.token

    def sync_auth_flow(self, request: httpx.Request) -> Generator[httpx.Request, Any, Any]:
        token = self.sync_get_token()
        request.headers["Authorization"] = f"Bearer {token}"
        yield request

    async def async_get_token(self) -> str:
        # The lock prevents concurrent coroutines from racing to refresh.
        async with self._async_lock:
            if self._needs_refresh():
                # Bug fix: store the fetched token (see sync_get_token).
                self.cached_token = await self._credential.get_token(
                    "https://cognitiveservices.azure.com/.default"
                )
            return self.cached_token.token

    async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, Any]:
        token = await self.async_get_token()
        request.headers["Authorization"] = f"Bearer {token}"
        yield request
Loading