diff --git a/src/openai/azure/_async_client.py b/src/openai/azure/_async_client.py index 77708fa74f..feb271a0e3 100644 --- a/src/openai/azure/_async_client.py +++ b/src/openai/azure/_async_client.py @@ -1,6 +1,7 @@ from typing_extensions import Literal, override from typing import Any, Callable, cast, List, Mapping, Dict, Optional, overload, Type, Union import time +import json import httpx @@ -378,6 +379,25 @@ def auth_headers(self) -> Dict[str, str]: return { 'Authorization': f'Bearer {self.credential.get_token()}'} return {"api-key": self.api_key} + def _prepare_request(self, request: httpx.Request) -> None: + # TODO: need confirmation that it is okay to override this + url = request.url + if url.path.startswith("/audio"): + model_name = request.stream.fields[0].value # is this a robust way to extract model name? + request.url = url.copy_with(path=f"/openai/deployments/{model_name}{url.path}") + return + + try: + content = json.loads(request.content) + except json.JSONDecodeError: + return + + if content.get("dataSources"): + request.url = url.copy_with(path=f"/openai/deployments/{content['model']}/extensions{url.path}") + elif request.url.path == "/images/generations": + request.url = url.copy_with(path="/openai/images/generations:submit") + elif content.get("model"): + request.url = url.copy_with(path=f"/openai/deployments/{content['model']}{url.path}") def _check_polling_response(self, response: httpx.Response, predicate: Callable[[httpx.Response], bool]) -> bool: if not predicate(response): @@ -416,21 +436,14 @@ async def _poll( # NOTE: We override the internal method because `@overrid`ing `@overload`ed methods and keeping typing happy is a pain. Most typing tools are lacking... async def _request(self, cast_to: Type[ResponseT], options: FinalRequestOptions, **kwargs: Any) -> Any: - if options.url == "/images/generations": - options.url = "openai/images/generations:submit" - response = await super()._request(cast_to=cast_to, options=options, **kwargs) + response = await super()._request(cast_to=cast_to, options=options, **kwargs) + if isinstance(response, ImagesResponse): model_extra = cast(Mapping[str, Any], getattr(response, 'model_extra')) or {} - operation_id = cast(str, model_extra['id']) - return await self._poll( - "get", f"openai/operations/images/{operation_id}", - until=lambda response: response.json()["status"] in ["succeeded"], - failed=lambda response: response.json()["status"] in ["failed"], - ) - if isinstance(options.json_data, Mapping): - model = cast(str, options.json_data["model"]) - if not options.url.startswith(f'openai/deployments/{model}'): - if options.extra_json and options.extra_json.get("dataSources"): - options.url = f'openai/deployments/{model}/extensions' + options.url - else: - options.url = f'openai/deployments/{model}' + options.url + if model_extra.get("id"): + operation_id = cast(str, model_extra['id']) + return await self._poll( + "get", f"openai/operations/images/{operation_id}", + until=lambda response: response.json()["status"] in ["succeeded"], + failed=lambda response: response.json()["status"] in ["failed"], + ) return await super()._request(cast_to=cast_to, options=options, **kwargs) \ No newline at end of file diff --git a/src/openai/azure/_sync_client.py b/src/openai/azure/_sync_client.py index ba7faccf20..1e3f026b91 100644 --- a/src/openai/azure/_sync_client.py +++ b/src/openai/azure/_sync_client.py @@ -1,6 +1,7 @@ from typing_extensions import Literal, override from typing import Any, Callable, cast, List, Mapping, Dict, Optional, overload, Union import time +import json import httpx @@ -380,26 +381,40 @@ def auth_headers(self) -> Dict[str, str]: return { 'Authorization': f'Bearer {self.credential.get_token()}'} return {"api-key": self.api_key} + def _prepare_request(self, request: httpx.Request) -> None: + # TODO: need confirmation that it is okay to override this + url = request.url + if url.path.startswith("/audio"): + model_name = request.stream.fields[0].value # is this a robust way to extract model name? + request.url = url.copy_with(path=f"/openai/deployments/{model_name}{url.path}") + return + + try: + content = json.loads(request.content) + except json.JSONDecodeError: + return + + if content.get("dataSources"): + request.url = url.copy_with(path=f"/openai/deployments/{content['model']}/extensions{url.path}") + elif request.url.path == "/images/generations": + request.url = url.copy_with(path="/openai/images/generations:submit") + elif content.get("model"): + request.url = url.copy_with(path=f"/openai/deployments/{content['model']}{url.path}") + # NOTE: We override the internal method because `@overrid`ing `@overload`ed methods and keeping typing happy is a pain. Most typing tools are lacking... def _request(self, *, options: FinalRequestOptions, **kwargs: Any) -> Any: - if options.url == "/images/generations": - options.url = "openai/images/generations:submit" - response = super()._request(options=options, **kwargs) + response = super()._request(options=options, **kwargs) + if isinstance(response, ImagesResponse): model_extra = cast(Mapping[str, Any], getattr(response, 'model_extra')) or {} - operation_id = cast(str, model_extra['id']) - return self._poll( - "get", f"openai/operations/images/{operation_id}", - until=lambda response: response.json()["status"] in ["succeeded"], - failed=lambda response: response.json()["status"] in ["failed"], - ) - if isinstance(options.json_data, Mapping): - model = cast(str, options.json_data["model"]) - if not options.url.startswith(f'openai/deployments/{model}'): - if options.extra_json and options.extra_json.get("dataSources"): - options.url = f'openai/deployments/{model}/extensions' + options.url - else: - options.url = f'openai/deployments/{model}' + options.url - return super()._request(options=options, **kwargs) + if model_extra.get("id"): + operation_id = cast(str, model_extra['id']) + return self._poll( + "get", f"openai/operations/images/{operation_id}", + until=lambda response: response.json()["status"] in ["succeeded"], + failed=lambda response: response.json()["status"] in ["failed"], + ) + + return response # Internal azure specific "helper" methods def _check_polling_response(self, response: httpx.Response, predicate: Callable[[httpx.Response], bool]) -> bool: