From 50c184549bade8a83f7190fd437d0b0c37f6f264 Mon Sep 17 00:00:00 2001 From: Krista Pratico Date: Thu, 5 Oct 2023 19:00:44 -0700 Subject: [PATCH 1/2] prepare_request option --- src/openai/azure/_async_client.py | 40 +++++++++++++++++----------- src/openai/azure/_sync_client.py | 44 +++++++++++++++++++------------ 2 files changed, 51 insertions(+), 33 deletions(-) diff --git a/src/openai/azure/_async_client.py b/src/openai/azure/_async_client.py index 77708fa74f..92e35bf764 100644 --- a/src/openai/azure/_async_client.py +++ b/src/openai/azure/_async_client.py @@ -1,6 +1,7 @@ from typing_extensions import Literal, override from typing import Any, Callable, cast, List, Mapping, Dict, Optional, overload, Type, Union import time +import json import httpx @@ -378,6 +379,20 @@ def auth_headers(self) -> Dict[str, str]: return { 'Authorization': f'Bearer {self.credential.get_token()}'} return {"api-key": self.api_key} + def _prepare_request(self, request: httpx.Request) -> None: + # TODO: need confirmation that it is okay to override this + # TODO: url building feels hacky - do better + try: + content = json.loads(request.content) + except json.JSONDecodeError: + return + url = request.url + if content.get("dataSources"): + request.url = httpx.URL(f"{url.scheme}://" + url.host + f"/openai/deployments/{content['model']}/extensions" + url.path + f"?{url.query.decode()}") + elif request.url.path == "/images/generations": + request.url = httpx.URL(f"{url.scheme}://" + url.host + "/openai/images/generations:submit?" + f"{url.query.decode()}") + elif content.get("model"): + request.url = httpx.URL(f"{url.scheme}://" + url.host + f"/openai/deployments/{content['model']}" + url.path + f"?{url.query.decode()}") def _check_polling_response(self, response: httpx.Response, predicate: Callable[[httpx.Response], bool]) -> bool: if not predicate(response): @@ -416,21 +431,14 @@ async def _poll( # NOTE: We override the internal method because `@overrid`ing `@overload`ed methods and keeping typing happy is a pain. Most typing tools are lacking... async def _request(self, cast_to: Type[ResponseT], options: FinalRequestOptions, **kwargs: Any) -> Any: - if options.url == "/images/generations": - options.url = "openai/images/generations:submit" - response = await super()._request(cast_to=cast_to, options=options, **kwargs) + response = await super()._request(cast_to=cast_to, options=options, **kwargs) + if isinstance(response, ImagesResponse): model_extra = cast(Mapping[str, Any], getattr(response, 'model_extra')) or {} - operation_id = cast(str, model_extra['id']) - return await self._poll( - "get", f"openai/operations/images/{operation_id}", - until=lambda response: response.json()["status"] in ["succeeded"], - failed=lambda response: response.json()["status"] in ["failed"], - ) - if isinstance(options.json_data, Mapping): - model = cast(str, options.json_data["model"]) - if not options.url.startswith(f'openai/deployments/{model}'): - if options.extra_json and options.extra_json.get("dataSources"): - options.url = f'openai/deployments/{model}/extensions' + options.url - else: - options.url = f'openai/deployments/{model}' + options.url + if model_extra.get("id"): + operation_id = cast(str, model_extra['id']) + return await self._poll( + "get", f"openai/operations/images/{operation_id}", + until=lambda response: response.json()["status"] in ["succeeded"], + failed=lambda response: response.json()["status"] in ["failed"], + ) return await super()._request(cast_to=cast_to, options=options, **kwargs) \ No newline at end of file diff --git a/src/openai/azure/_sync_client.py b/src/openai/azure/_sync_client.py index ba7faccf20..4c9649d153 100644 --- a/src/openai/azure/_sync_client.py +++ b/src/openai/azure/_sync_client.py @@ -1,6 +1,7 @@ from typing_extensions import Literal, override from typing import Any, Callable, cast, List, Mapping, Dict, Optional, overload, Union import time +import json import httpx @@ -380,26 +381,35 @@ def auth_headers(self) -> Dict[str, str]: return { 'Authorization': f'Bearer {self.credential.get_token()}'} return {"api-key": self.api_key} + def _prepare_request(self, request: httpx.Request) -> None: + # TODO: need confirmation that it is okay to override this + # TODO: url building feels hacky - do better + try: + content = json.loads(request.content) + except json.JSONDecodeError: + return + url = request.url + # TODO: url building feels hacky - do better + if content.get("dataSources"): + request.url = httpx.URL(f"{url.scheme}://" + url.host + f"/openai/deployments/{content['model']}/extensions" + url.path + f"?{url.query.decode()}") + elif request.url.path == "/images/generations": + request.url = httpx.URL(f"{url.scheme}://" + url.host + "/openai/images/generations:submit?" + f"{url.query.decode()}") + elif content.get("model"): + request.url = httpx.URL(f"{url.scheme}://" + url.host + f"/openai/deployments/{content['model']}" + url.path + f"?{url.query.decode()}") + # NOTE: We override the internal method because `@overrid`ing `@overload`ed methods and keeping typing happy is a pain. Most typing tools are lacking... def _request(self, *, options: FinalRequestOptions, **kwargs: Any) -> Any: - if options.url == "/images/generations": - options.url = "openai/images/generations:submit" - response = super()._request(options=options, **kwargs) + response = super()._request(options=options, **kwargs) + if isinstance(response, ImagesResponse): model_extra = cast(Mapping[str, Any], getattr(response, 'model_extra')) or {} - operation_id = cast(str, model_extra['id']) - return self._poll( - "get", f"openai/operations/images/{operation_id}", - until=lambda response: response.json()["status"] in ["succeeded"], - failed=lambda response: response.json()["status"] in ["failed"], - ) - if isinstance(options.json_data, Mapping): - model = cast(str, options.json_data["model"]) - if not options.url.startswith(f'openai/deployments/{model}'): - if options.extra_json and options.extra_json.get("dataSources"): - options.url = f'openai/deployments/{model}/extensions' + options.url - else: - options.url = f'openai/deployments/{model}' + options.url - return super()._request(options=options, **kwargs) + if model_extra.get("id"): + operation_id = cast(str, model_extra['id']) + return self._poll( + "get", f"openai/operations/images/{operation_id}", + until=lambda response: response.json()["status"] in ["succeeded"], + failed=lambda response: response.json()["status"] in ["failed"], + ) + return response # Internal azure specific "helper" methods def _check_polling_response(self, response: httpx.Response, predicate: Callable[[httpx.Response], bool]) -> bool: From 2be5ce3f0ab8e6c70c7b26e0d2aa6c6d58dc4027 Mon Sep 17 00:00:00 2001 From: Krista Pratico Date: Fri, 6 Oct 2023 13:42:34 -0700 Subject: [PATCH 2/2] use copy_with, workaround for audio --- src/openai/azure/_async_client.py | 15 ++++++++++----- src/openai/azure/_sync_client.py | 17 +++++++++++------ 2 files changed, 21 insertions(+), 11 deletions(-) diff --git a/src/openai/azure/_async_client.py b/src/openai/azure/_async_client.py index 92e35bf764..feb271a0e3 100644 --- a/src/openai/azure/_async_client.py +++ b/src/openai/azure/_async_client.py @@ -381,18 +381,23 @@ def auth_headers(self) -> Dict[str, str]: def _prepare_request(self, request: httpx.Request) -> None: # TODO: need confirmation that it is okay to override this - # TODO: url building feels hacky - do better + url = request.url + if url.path.startswith("/audio"): + model_name = request.stream.fields[0].value # is this a robust way to extract model name? + request.url = url.copy_with(path=f"/openai/deployments/{model_name}{url.path}") + return + try: content = json.loads(request.content) except json.JSONDecodeError: return - url = request.url + if content.get("dataSources"): - request.url = httpx.URL(f"{url.scheme}://" + url.host + f"/openai/deployments/{content['model']}/extensions" + url.path + f"?{url.query.decode()}") + request.url = url.copy_with(path=f"/openai/deployments/{content['model']}/extensions{url.path}") elif request.url.path == "/images/generations": - request.url = httpx.URL(f"{url.scheme}://" + url.host + "/openai/images/generations:submit?" + f"{url.query.decode()}") + request.url = url.copy_with(path="/openai/images/generations:submit") elif content.get("model"): - request.url = httpx.URL(f"{url.scheme}://" + url.host + f"/openai/deployments/{content['model']}" + url.path + f"?{url.query.decode()}") + request.url = url.copy_with(path=f"/openai/deployments/{content['model']}{url.path}") def _check_polling_response(self, response: httpx.Response, predicate: Callable[[httpx.Response], bool]) -> bool: if not predicate(response): diff --git a/src/openai/azure/_sync_client.py b/src/openai/azure/_sync_client.py index 4c9649d153..1e3f026b91 100644 --- a/src/openai/azure/_sync_client.py +++ b/src/openai/azure/_sync_client.py @@ -383,19 +383,23 @@ def auth_headers(self) -> Dict[str, str]: def _prepare_request(self, request: httpx.Request) -> None: # TODO: need confirmation that it is okay to override this - # TODO: url building feels hacky - do better + url = request.url + if url.path.startswith("/audio"): + model_name = request.stream.fields[0].value # is this a robust way to extract model name? + request.url = url.copy_with(path=f"/openai/deployments/{model_name}{url.path}") + return + try: content = json.loads(request.content) except json.JSONDecodeError: return - url = request.url - # TODO: url building feels hacky - do better + if content.get("dataSources"): - request.url = httpx.URL(f"{url.scheme}://" + url.host + f"/openai/deployments/{content['model']}/extensions" + url.path + f"?{url.query.decode()}") + request.url = url.copy_with(path=f"/openai/deployments/{content['model']}/extensions{url.path}") elif request.url.path == "/images/generations": - request.url = httpx.URL(f"{url.scheme}://" + url.host + "/openai/images/generations:submit?" + f"{url.query.decode()}") + request.url = url.copy_with(path="/openai/images/generations:submit") elif content.get("model"): - request.url = httpx.URL(f"{url.scheme}://" + url.host + f"/openai/deployments/{content['model']}" + url.path + f"?{url.query.decode()}") + request.url = url.copy_with(path=f"/openai/deployments/{content['model']}{url.path}") # NOTE: We override the internal method because `@overrid`ing `@overload`ed methods and keeping typing happy is a pain. Most typing tools are lacking... def _request(self, *, options: FinalRequestOptions, **kwargs: Any) -> Any: @@ -409,6 +413,7 @@ def _request(self, *, options: FinalRequestOptions, **kwargs: Any) -> Any: until=lambda response: response.json()["status"] in ["succeeded"], failed=lambda response: response.json()["status"] in ["failed"], ) + return response # Internal azure specific "helper" methods