-
Notifications
You must be signed in to change notification settings - Fork 1
Client abstraction over API resources (prototype) #2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
6922ae8
e1c33e6
ac61e45
b894768
b4fdcd6
a4f45ef
49c7c8b
bd4960c
91e9b58
f992de1
7a31602
e3b59de
24dad1d
8bb7d61
214242c
917b97e
e1edec4
48adec5
e82ed27
0cfaaff
d9b43f5
5809b39
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
import logging | ||
import typing | ||
|
||
import openai | ||
|
||
LATEST_AZURE_API_VERSION = "2023-05-15" | ||
|
||
|
||
class AzureTokenAuth: | ||
def __init__(self, credential=None): | ||
if not credential: | ||
try: | ||
import azure.identity | ||
except ImportError: | ||
raise Exception( | ||
"You have to install the azure-identity package in order to use AzureTokenAuth" | ||
) | ||
credential = azure.identity.DefaultAzureCredential() | ||
self._credential = credential | ||
self._cached_token = None | ||
|
||
def get_token(self) -> str: | ||
if self._cached_token is None: | ||
self._cached_token = self._credential.get_token( | ||
"https://cognitiveservices.azure.com/.default" | ||
) | ||
return self._cached_token.token | ||
|
||
|
||
class ApiKeyAuth: | ||
kristapratico marked this conversation as resolved.
Show resolved
Hide resolved
|
||
def __init__(self, key: str = ""): | ||
self.key = key or openai.api_key | ||
|
||
def get_token(self) -> str: | ||
return self.key | ||
|
||
|
||
Backends = typing.Literal["azure", "openai", ""] | ||
johanste marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
|
||
class OpenAIClient: | ||
def __init__( | ||
kristapratico marked this conversation as resolved.
Show resolved
Hide resolved
|
||
self, | ||
*, | ||
api_base: str = "", | ||
auth: typing.Union[str, ApiKeyAuth, AzureTokenAuth] = "", | ||
api_version: str = "", | ||
backend: Backends = "", | ||
): | ||
self.api_base = api_base or openai.api_base | ||
if auth == "azuredefault": | ||
self.auth = AzureTokenAuth() | ||
elif isinstance(auth, str): | ||
self.auth = ApiKeyAuth(auth or openai.api_key) | ||
else: | ||
self.auth = auth | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Wonder if we should just do validation here if it's not one of the supported auth types. Otherwise we're going to get an AttributeError when we do the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We could. I think we are side-stepping the biggest issue (passing in a I was planning to make a better diagnostics story (adding the |
||
|
||
# Pick up api type from parameter or environment | ||
backend = backend or ( | ||
"azure" if openai.api_type in ("azure", "azure_ad") else "openai" | ||
) | ||
|
||
self.backend = backend | ||
kristapratico marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
if backend == "azure": | ||
self.api_version = ( | ||
api_version or openai.api_version or LATEST_AZURE_API_VERSION | ||
) | ||
if isinstance(self.auth, AzureTokenAuth): | ||
self.api_type = "azure_ad" | ||
else: | ||
self.api_type = "azure" | ||
elif backend in ("openai", ""): | ||
self.api_type = "open_ai" | ||
self.api_version = api_version or openai.api_version | ||
else: | ||
raise ValueError( | ||
f'Unknown `backend` {backend} - expected one of "azure" or "openai"' | ||
) | ||
|
||
def _populate_args(self, kwargs: typing.Dict[str, typing.Any], **overrides) -> None: | ||
backend = self.backend | ||
|
||
kwargs.setdefault("api_base", self.api_base or openai.api_base) | ||
kwargs.setdefault("api_key", self.auth.get_token()) | ||
kwargs.setdefault("api_type", self.api_type) | ||
if self.api_version: | ||
kwargs.setdefault("api_version", self.api_version) | ||
|
||
for key, val in overrides.items(): | ||
kwargs.setdefault(key, val) | ||
if kwargs[key] != val: | ||
raise ValueError(f"No parameter named `{key}`") | ||
|
||
def completion(self, prompt: str, **kwargs) -> openai.Completion: | ||
self._populate_args(kwargs, prompt=prompt, stream=False) | ||
return typing.cast(openai.Completion, openai.Completion.create(**kwargs)) | ||
|
||
def iter_completion( | ||
self, prompt: str, **kwargs | ||
) -> typing.Iterable[openai.Completion]: | ||
self._populate_args(kwargs, prompt=prompt, stream=True) | ||
return typing.cast( | ||
typing.Iterable[openai.Completion], openai.Completion.create(**kwargs) | ||
) | ||
|
||
def chatcompletion(self, messages, **kwargs) -> openai.ChatCompletion: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: since we snake_case the image API names, this kind of feels like it should follow that as well There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, I do think we should review the names across the board. I used the convention of |
||
self._populate_args(kwargs, messages=messages, stream=False) | ||
return typing.cast( | ||
openai.ChatCompletion, openai.ChatCompletion.create(**kwargs) | ||
) | ||
|
||
def iter_chatcompletion( | ||
self, messages, **kwargs | ||
) -> typing.Iterable[openai.ChatCompletion]: | ||
self._populate_args(kwargs, messages=messages, stream=True) | ||
return typing.cast( | ||
typing.Iterable[openai.ChatCompletion], | ||
openai.ChatCompletion.create(**kwargs), | ||
) | ||
|
||
def embeddings(self, input, **kwargs): | ||
self._populate_args(kwargs, input=input) | ||
return typing.cast(openai.Embedding, openai.Embedding.create(**kwargs)) | ||
|
||
|
||
if __name__ == "__main__": | ||
client = OpenAIClient( | ||
api_base="https://achand-openai-0.openai.azure.com/", | ||
auth="azuredefault", | ||
backend="azure", | ||
) | ||
print(client.completion("what is up, my friend?", deployment_id="chatgpt")) | ||
print(client.embeddings("What, or what is this?", deployment_id="arch")) # Doesn't work 'cause it is the wrong model... | ||
oaiclient = OpenAIClient() | ||
print(oaiclient.completion("what is up, my friend?", model="text-davinci-003")) | ||
print(oaiclient.embeddings("What are embeddings?", model="text-embedding-ada-002")) |
Uh oh!
There was an error while loading. Please reload this page.