Skip to content

consider a custom exception for rai #14

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: v1_azure
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/openai/azure/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from ._sync_client import AzureOpenAIClient
from ._async_client import AsyncAzureOpenAIClient
from ._credential import TokenCredential
from ._exceptions import ContentPolicyError

__all__ = [
"AzureOpenAIClient",
"TokenCredential",
"AsyncAzureOpenAIClient",
"ContentPolicyError",
]
23 changes: 22 additions & 1 deletion src/openai/azure/_azuremodels.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,26 @@
from typing import TypedDict
from typing import Optional
from typing_extensions import TypedDict, Literal
from openai._models import BaseModel as BaseModel


class ChatExtensionConfiguration(TypedDict):
type: str
parameters: object

# TODO: just copying in from other PR
class ContentFilterResult(BaseModel):
severity: Literal["safe", "low", "medium", "high"]
filtered: bool


class Error(BaseModel):
code: str
message: str


class ContentFilterResults(BaseModel):
hate: Optional[ContentFilterResult]
self_harm: Optional[ContentFilterResult]
violence: Optional[ContentFilterResult]
sexual: Optional[ContentFilterResult]
error: Optional[Error]
17 changes: 17 additions & 0 deletions src/openai/azure/_exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from __future__ import annotations
import httpx
from openai import BadRequestError
from ._azuremodels import ContentFilterResults


class ContentPolicyError(BadRequestError):
code: str
message: str
content_filter_result: ContentFilterResults

def __init__(self, message: str, *, response: httpx.Response, body: object | None) -> None:
super().__init__(message=message, response=response, body=body)
self.code = body["error"]["code"]
self.message = body["error"]["message"]
self.error = body["error"]
self.content_filter_result = ContentFilterResults.construct(**body["error"]["innererror"]["content_filter_result"])
20 changes: 18 additions & 2 deletions src/openai/azure/_sync_client.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from typing_extensions import Literal, override
from typing import Any, Callable, cast, List, Mapping, Dict, Optional, overload, Union
import time
import json

import httpx

from openai import Client, OpenAIError
from openai import Client, OpenAIError, BadRequestError
from openai.types import ImagesResponse

# These are types used in the public API surface area that are not exported as public
Expand All @@ -21,6 +22,7 @@
# Azure specific types
from ._credential import TokenCredential
from ._azuremodels import ChatExtensionConfiguration
from ._exceptions import ContentPolicyError

TIMEOUT_SECS = 600

Expand Down Expand Up @@ -399,7 +401,21 @@ def _request(self, *, options: FinalRequestOptions, **kwargs: Any) -> Any:
options.url = f'openai/deployments/{model}/extensions' + options.url
else:
options.url = f'openai/deployments/{model}' + options.url
return super()._request(options=options, **kwargs)
try:
return super()._request(options=options, **kwargs)
except BadRequestError as err:
try:
body = json.loads(err.response.text)
except Exception:
raise err

if body.get('error') and body['error'].get('code') == 'content_filter':
raise ContentPolicyError(
message=err.message,
response=err.response,
body=body
)
raise err

# Internal azure specific "helper" methods
def _check_polling_response(self, response: httpx.Response, predicate: Callable[[httpx.Response], bool]) -> bool:
Expand Down