Skip to content

Client abstraction over API resources (prototype) #2

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 22 commits into
base: main
Choose a base branch
from
Draft
Changes from 1 commit
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
6922ae8
Client abstraction over API resources (prototype)
johanste Jun 9, 2023
e1c33e6
Added some async endpoints
johanste Jun 12, 2023
ac61e45
Added image APIs
johanste Jun 13, 2023
b894768
Allow the use of `model` as a synonym for `deployment_id` when using …
johanste Jun 13, 2023
b4fdcd6
Formatteding client.py (black)
johanste Jun 13, 2023
a4f45ef
Added basic tests, more validation for client
johanste Jun 14, 2023
49c7c8b
Update openai/client.py
johanste Jun 15, 2023
bd4960c
Update openai/client.py
johanste Jun 15, 2023
91e9b58
Review feedback
johanste Jun 17, 2023
f992de1
Merge branch 'client' of github.com:johanste/openai-python into client
johanste Jun 17, 2023
7a31602
Fix tests
johanste Jun 17, 2023
e3b59de
adding docstrings initial
kristapratico Jun 20, 2023
24dad1d
Add live tests (#4)
kristapratico Jun 22, 2023
8bb7d61
Add edit, audio, and moderation APIs to client abstraction (#5)
kristapratico Jun 22, 2023
214242c
Merge branch 'client' into docstrings
kristapratico Jun 22, 2023
917b97e
Better test fixtures
johanste Jun 23, 2023
e1edec4
more docstrings
kristapratico Jun 28, 2023
48adec5
docstrings + type hints + remove prompt edits (deprecated)
kristapratico Jul 19, 2023
e82ed27
fix find/replace mistake
kristapratico Jul 19, 2023
0cfaaff
add missing tests + adjust moderation kwargs passed
kristapratico Jul 24, 2023
d9b43f5
Merge branch 'openai:main' into client
johanste Jul 25, 2023
5809b39
Merge pull request #6 from johanste/docstrings
kristapratico Jul 27, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Client abstraction over API resources (prototype)
  • Loading branch information
johanste committed Jun 9, 2023
commit 6922ae83d04545ccd0b6b285f41e908031dc6743
137 changes: 137 additions & 0 deletions openai/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
import logging
import typing

import openai

LATEST_AZURE_API_VERSION = "2023-05-15"


class AzureTokenAuth:
def __init__(self, credential=None):
if not credential:
try:
import azure.identity
except ImportError:
raise Exception(
"You have to install the azure-identity package in order to use AzureTokenAuth"
)
credential = azure.identity.DefaultAzureCredential()
self._credential = credential
self._cached_token = None

def get_token(self) -> str:
if self._cached_token is None:
self._cached_token = self._credential.get_token(
"https://cognitiveservices.azure.com/.default"
)
return self._cached_token.token


class ApiKeyAuth:
def __init__(self, key: str = ""):
self.key = key or openai.api_key

def get_token(self) -> str:
return self.key


Backends = typing.Literal["azure", "openai", ""]


class OpenAIClient:
def __init__(
self,
*,
api_base: str = "",
auth: typing.Union[str, ApiKeyAuth, AzureTokenAuth] = "",
api_version: str = "",
backend: Backends = "",
):
self.api_base = api_base or openai.api_base
if auth == "azuredefault":
self.auth = AzureTokenAuth()
elif isinstance(auth, str):
self.auth = ApiKeyAuth(auth or openai.api_key)
else:
self.auth = auth
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wonder if we should just do validation here if it's not one of the supported auth types. Otherwise we're going to get an AttributeError when we do the self.auth.get_token() call on L125.

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could. I think we are side-stepping the biggest issue (passing in a str for api key based auth), but I wouldn't mind adding more validation here.

I was planning to make a better diagnostics story (adding the __repr__ methods started it) that would make it easier to see how your client was configured - as well as how to deal with OAI or AOAI specific endpoints.


# Pick up api type from parameter or environment
backend = backend or (
"azure" if openai.api_type in ("azure", "azure_ad") else "openai"
)

self.backend = backend

if backend == "azure":
self.api_version = (
api_version or openai.api_version or LATEST_AZURE_API_VERSION
)
if isinstance(self.auth, AzureTokenAuth):
self.api_type = "azure_ad"
else:
self.api_type = "azure"
elif backend in ("openai", ""):
self.api_type = "open_ai"
self.api_version = api_version or openai.api_version
else:
raise ValueError(
f'Unknown `backend` {backend} - expected one of "azure" or "openai"'
)

def _populate_args(self, kwargs: typing.Dict[str, typing.Any], **overrides) -> None:
backend = self.backend

kwargs.setdefault("api_base", self.api_base or openai.api_base)
kwargs.setdefault("api_key", self.auth.get_token())
kwargs.setdefault("api_type", self.api_type)
if self.api_version:
kwargs.setdefault("api_version", self.api_version)

for key, val in overrides.items():
kwargs.setdefault(key, val)
if kwargs[key] != val:
raise ValueError(f"No parameter named `{key}`")

def completion(self, prompt: str, **kwargs) -> openai.Completion:
self._populate_args(kwargs, prompt=prompt, stream=False)
return typing.cast(openai.Completion, openai.Completion.create(**kwargs))

def iter_completion(
self, prompt: str, **kwargs
) -> typing.Iterable[openai.Completion]:
self._populate_args(kwargs, prompt=prompt, stream=True)
return typing.cast(
typing.Iterable[openai.Completion], openai.Completion.create(**kwargs)
)

def chatcompletion(self, messages, **kwargs) -> openai.ChatCompletion:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: since we snake_case the image API names, this kind of feels like it should follow that as well

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I do think we should review the names across the board. I used the convention of iter_ for server sent event streaming APIs and _<variation name> for "child"/custom actions beyond the base generation API. But that was just the first thing that popped into my head.

self._populate_args(kwargs, messages=messages, stream=False)
return typing.cast(
openai.ChatCompletion, openai.ChatCompletion.create(**kwargs)
)

def iter_chatcompletion(
self, messages, **kwargs
) -> typing.Iterable[openai.ChatCompletion]:
self._populate_args(kwargs, messages=messages, stream=True)
return typing.cast(
typing.Iterable[openai.ChatCompletion],
openai.ChatCompletion.create(**kwargs),
)

def embeddings(self, input, **kwargs):
self._populate_args(kwargs, input=input)
return typing.cast(openai.Embedding, openai.Embedding.create(**kwargs))


if __name__ == "__main__":
client = OpenAIClient(
api_base="https://achand-openai-0.openai.azure.com/",
auth="azuredefault",
backend="azure",
)
print(client.completion("what is up, my friend?", deployment_id="chatgpt"))
print(client.embeddings("What, or what is this?", deployment_id="arch")) # Doesn't work 'cause it is the wrong model...
oaiclient = OpenAIClient()
print(oaiclient.completion("what is up, my friend?", model="text-davinci-003"))
print(oaiclient.embeddings("What are embeddings?", model="text-embedding-ada-002"))