
clean up Azure client #16


Open · wants to merge 6 commits into base: v1_azure
3 changes: 2 additions & 1 deletion src/openai/azure/__init__.py
@@ -2,8 +2,9 @@
from ._async_client import AsyncAzureOpenAIClient
from ._credential import TokenCredential


__all__ = [
"AzureOpenAIClient",
"TokenCredential",
"AsyncAzureOpenAIClient",
]
]
117 changes: 37 additions & 80 deletions src/openai/azure/_async_client.py
@@ -1,36 +1,37 @@
from __future__ import annotations

from typing_extensions import Literal, override
from typing import Any, Callable, cast, List, Mapping, Dict, Optional, overload, Type, Union
import time
from typing import Any, cast, List, Dict, Optional, overload, Union, AsyncIterable

import httpx

from openai import AsyncClient, OpenAIError
from openai import AsyncClient
from openai.resources.chat import AsyncChat, AsyncCompletions
from openai.resources.completions import AsyncCompletions as AsyncCompletionsOperations
from openai.types import ImagesResponse

from openai.types.chat import ChatCompletionMessageParam, ChatCompletion, ChatCompletionChunk
from openai.types.chat.completion_create_params import FunctionCall, Function
from openai.types.completion import Completion

# These types are needed for correct typing of overrides
from openai._types import NotGiven, NOT_GIVEN, Headers, Query, Body, ResponseT
from openai._types import NotGiven, NOT_GIVEN, Headers, Query, Body

# These are types used in the public API surface area that are not exported as public
from openai._models import FinalRequestOptions
from openai._streaming import AsyncStream

# Azure specific types
from ._credential import TokenCredential, TokenAuth
from ._azuremodels import (
ChatExtensionConfiguration,
AzureChatCompletion,
AzureChatCompletion,
AzureChatCompletionChunk,
AzureCompletion,
)

TIMEOUT_SECS = 600
async def async_iterator(response, response_cls):
async for result in response:
yield response_cls.construct(**result.model_dump(mode="json"))


class AsyncAzureChat(AsyncChat):

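The `async_iterator` helper replaces the earlier `response._cast_to` mutation: each item coming off the underlying `AsyncStream` is dumped to a dict and rebuilt as the Azure-specific model, so callers receive `AzureChatCompletionChunk` / `AzureCompletion` objects from a plain async generator. A minimal, self-contained sketch of the same pattern, with hypothetical pydantic models standing in for the real chunk types:

```python
import asyncio

from pydantic import BaseModel


class Chunk(BaseModel):
    """Hypothetical stand-in for ChatCompletionChunk."""
    id: str


class AzureChunk(Chunk):
    """Hypothetical stand-in for AzureChatCompletionChunk."""


async def fake_stream():
    # Stand-in for the AsyncStream[ChatCompletionChunk] returned by the base client.
    for i in range(3):
        yield Chunk(id=f"chunk-{i}")


async def async_iterator(response, response_cls):
    # Same shape as the helper in this PR: dump each item to a dict and
    # rebuild it as the Azure model via construct(), skipping validation.
    async for result in response:
        yield response_cls.construct(**result.model_dump(mode="json"))


async def main() -> None:
    async for chunk in async_iterator(fake_stream(), AzureChunk):
        print(type(chunk).__name__, chunk.id)


asyncio.run(main())
```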
@@ -210,7 +211,7 @@ async def create(
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | None | NotGiven = NOT_GIVEN,
) -> AsyncStream[AzureChatCompletionChunk]:
) -> AsyncIterable[AzureChatCompletionChunk]:
"""
Creates a model response for the given chat conversation.

@@ -335,7 +336,7 @@ async def create(
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | None | NotGiven = NOT_GIVEN,
) -> AzureChatCompletion | AsyncStream[AzureChatCompletionChunk]:
) -> AzureChatCompletion | AsyncIterable[AzureChatCompletionChunk]:
if data_sources:
if extra_body is None:
extra_body= {}
@@ -344,7 +345,7 @@
"stream": True
} if stream else {}
response = cast(
Union[ChatCompletion, ChatCompletionChunk],
Union[ChatCompletion, AsyncStream[ChatCompletionChunk]],
await super().create(
messages=messages,
model=model,
@@ -367,11 +368,14 @@
)
)
if isinstance(response, AsyncStream):
response._cast_to = AzureChatCompletionChunk # or rebuild the stream?
return async_iterator(
response=response,
response_cls=AzureChatCompletionChunk
)
else:
response_json = response.model_dump(mode="json")
response = AzureChatCompletion.construct(**response_json)
return response # type: ignore
return response # type: ignore


class AsyncAzureCompletions(AsyncCompletionsOperations):
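With the changes above, a streamed chat call yields `AzureChatCompletionChunk` objects from a plain `AsyncIterable` rather than an `AsyncStream`. A usage sketch; the endpoint, key, deployment name, and messages below are placeholders, not taken from this PR:

```python
import asyncio

from openai.azure import AsyncAzureOpenAIClient


async def main() -> None:
    client = AsyncAzureOpenAIClient(
        base_url="https://my-resource.openai.azure.com",  # placeholder endpoint
        api_key="<api-key>",  # placeholder key
    )

    stream = await client.chat.completions.create(
        model="my-deployment",  # placeholder deployment name
        messages=[{"role": "user", "content": "Hello"}],
        stream=True,
        # Azure chat extensions would go through data_sources=[...],
        # which the override folds into extra_body as "dataSources".
    )
    # AsyncIterable[AzureChatCompletionChunk], not AsyncStream.
    async for chunk in stream:
        print(chunk)


asyncio.run(main())
```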
@@ -569,7 +573,7 @@ async def create(
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | None | NotGiven = NOT_GIVEN,
) -> AsyncStream[AzureCompletion]:
) -> AsyncIterable[AzureCompletion]:
"""
Creates a completion for the provided prompt and parameters.

@@ -723,7 +727,7 @@ async def create(
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | None | NotGiven = NOT_GIVEN,
) -> AzureCompletion | AsyncStream[AzureCompletion]:
) -> AzureCompletion | AsyncIterable[AzureCompletion]:
stream_dict: Dict[str, Literal[True]] = { # TODO: pylance is upset if I pass through the parameter value. Overload + override combination is problematic
"stream": True
} if stream else {}
@@ -754,11 +758,14 @@
)

if isinstance(response, AsyncStream):
response._cast_to = AzureCompletion
return async_iterator(
response=response,
response_cls=AzureCompletion
)
else:
response_json = response.model_dump(mode="json")
response = AzureCompletion.construct(**response_json)
return response # type: ignore
return response # type: ignore


class AsyncAzureOpenAIClient(AsyncClient):
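The completions override follows the same pattern as chat: the non-streaming path re-wraps the OpenAI `Completion` as an `AzureCompletion` via `model_dump` plus `construct`, and `stream=True` returns an `AsyncIterable[AzureCompletion]` driven by the same `async_iterator` helper. A usage sketch under assumed wiring (endpoint, key, and deployment name are placeholders):

```python
import asyncio

from openai.azure import AsyncAzureOpenAIClient


async def main() -> None:
    client = AsyncAzureOpenAIClient(
        base_url="https://my-resource.openai.azure.com",  # placeholder endpoint
        api_key="<api-key>",  # placeholder key
    )

    # Non-streaming: the OpenAI Completion is re-wrapped as an AzureCompletion.
    completion = await client.completions.create(
        model="my-deployment",  # placeholder deployment name
        prompt="Say hello",
        max_tokens=16,
    )
    print(completion)


asyncio.run(main())
```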
@@ -777,14 +784,22 @@ def completions(self) -> AsyncAzureCompletions:
def completions(self, value: AsyncAzureCompletions) -> None:
self._completions = value

def __init__(self, *args: Any, credential: Optional["TokenCredential"] = None, api_version: str = '2023-09-01-preview', **kwargs: Any):
@overload
def __init__(self, *, base_url: str, api_key: str, api_version: str = '2023-09-01-preview', **kwargs: Any) -> None:
...

@overload
def __init__(self, *, base_url: str, credential: "AsyncTokenCredential", api_version: str = '2023-09-01-preview', **kwargs: Any) -> None:
...

def __init__(self, **kwargs: Any) -> None:
default_query = kwargs.get('default_query', {})
default_query.setdefault('api-version', api_version)
default_query.setdefault('api-version', kwargs.pop("api_version", '2023-09-01-preview'))
kwargs['default_query'] = default_query
self.credential = credential
if credential:
self.credential = kwargs.pop("credential", None)
if self.credential:
kwargs['api_key'] = 'Placeholder: AAD' # TODO: There is an assumption/validation there is always an API key.
super().__init__(*args, **kwargs)
super().__init__(**kwargs)
self._chat = AsyncAzureChat(self)

@property
@@ -795,61 +810,3 @@ def auth_headers(self) -> Dict[str, str]:
def custom_auth(self) -> httpx.Auth | None:
if self.credential:
return TokenAuth(self.credential)
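Taken together with the overloads above, there are two documented ways to construct the client; the endpoint, key, and concrete credential object below are placeholders. With an AAD credential, `api_key` is filled with a placeholder value internally and requests are authenticated through `custom_auth` / `TokenAuth`; in both cases `api-version` is added to `default_query` unless the caller already supplied one.

```python
from openai.azure import AsyncAzureOpenAIClient

# Key-based auth; endpoint and key are placeholders.
client = AsyncAzureOpenAIClient(
    base_url="https://my-resource.openai.azure.com",
    api_key="<api-key>",
    api_version="2023-09-01-preview",  # sent as the `api-version` query parameter
)

# AAD-based auth. The concrete credential type is an assumption here
# (anything satisfying the TokenCredential protocol used by TokenAuth):
# client = AsyncAzureOpenAIClient(
#     base_url="https://my-resource.openai.azure.com",
#     credential=my_token_credential,
# )
```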

def _check_polling_response(self, response: httpx.Response, predicate: Callable[[httpx.Response], bool]) -> bool:
if not predicate(response):
return False
error_data = response.json()['error']
message: str = cast(str, error_data.get('message', 'Operation failed'))
code = error_data.get('code')
raise OpenAIError(f'Error: {message} ({code})')

async def _poll(
self,
method: str,
url: str,
until: Callable[[httpx.Response], bool],
failed: Callable[[httpx.Response], bool],
interval: Optional[float] = None,
delay: Optional[float] = None,
) -> ImagesResponse:
if delay:
time.sleep(delay)

opts = FinalRequestOptions.construct(method=method, url=url)
response = await super().request(httpx.Response, opts)
self._check_polling_response(response, failed)
start_time = time.time()
while not until(response):
if time.time() - start_time > TIMEOUT_SECS:
raise Exception("Operation polling timed out.") # TODO: Fix up exception type.

time.sleep(interval or int(response.headers.get("retry-after")) or 10)
response = await super().request(httpx.Response, opts)
self._check_polling_response(response, failed)

response_json = response.json()
return ImagesResponse.construct(**response_json["result"])

# NOTE: We override the internal method because `@overrid`ing `@overload`ed methods and keeping typing happy is a pain. Most typing tools are lacking...
async def _request(self, cast_to: Type[ResponseT], options: FinalRequestOptions, **kwargs: Any) -> Any:
if options.url == "/images/generations":
options.url = "openai/images/generations:submit"
response = await super()._request(cast_to=cast_to, options=options, **kwargs)
model_extra = cast(Mapping[str, Any], getattr(response, 'model_extra')) or {}
operation_id = cast(str, model_extra['id'])
return await self._poll(
"get", f"openai/operations/images/{operation_id}",
until=lambda response: response.json()["status"] in ["succeeded"],
failed=lambda response: response.json()["status"] in ["failed"],
)
if isinstance(options.json_data, Mapping):
model = cast(str, options.json_data["model"])
if not options.url.startswith(f'openai/deployments/{model}'):
if options.extra_json and options.extra_json.get("dataSources"):
options.url = f'openai/deployments/{model}/extensions' + options.url
else:
options.url = f'openai/deployments/{model}' + options.url
if options.url.startswith(("/models", "/fine_tuning", "/files", "/fine-tunes")):
options.url = f"openai{options.url}"
return await super()._request(cast_to=cast_to, options=options, **kwargs)
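`_credential.py` is outside this diff, so the shape of `TokenCredential` / `TokenAuth` is not visible here. Purely as an illustration of the pattern `custom_auth` relies on, and not this repo's implementation, an `httpx.Auth` that injects a bearer token from a credential exposing an azure.core-style `get_token` might look like this (the scope and class name are assumptions):

```python
import httpx


class BearerTokenAuth(httpx.Auth):
    """Hypothetical stand-in for TokenAuth; the real implementation may differ."""

    # Assumed scope for Azure OpenAI / Cognitive Services.
    SCOPE = "https://cognitiveservices.azure.com/.default"

    def __init__(self, credential) -> None:
        # Assumed: credential.get_token(*scopes) returns an object with a
        # `.token` attribute, as in the azure.core credential protocols.
        self._credential = credential

    def auth_flow(self, request: httpx.Request):
        token = self._credential.get_token(self.SCOPE)
        request.headers["Authorization"] = f"Bearer {token.token}"
        yield request
        # An async credential (AsyncTokenCredential) would need async_auth_flow instead.
```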