Skip to content

Commit c65972e

Browse files
committed
Add retry logic for OpenAI chat completions
1 parent 7a91d2a commit c65972e

File tree

2 files changed

+17
-11
lines changed

2 files changed

+17
-11
lines changed

src/fastapi_app/rag_advanced.py

Lines changed: 15 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -1,15 +1,13 @@
11
import copy
2+
import logging
23
import pathlib
34
from collections.abc import AsyncGenerator
4-
from typing import (
5-
Any,
6-
)
5+
from typing import Any
76

87
from openai import AsyncOpenAI
9-
from openai.types.chat import (
10-
ChatCompletion,
11-
)
8+
from openai.types.chat import ChatCompletion
129
from openai_messages_token_helper import get_token_limit
10+
from tenacity import before_sleep_log, retry, stop_after_attempt, wait_random_exponential
1311

1412
from .api_models import ThoughtStep
1513
from .postgres_searcher import PostgresSearcher
@@ -20,6 +18,9 @@
2018
handle_specify_package_function_call,
2119
)
2220

21+
# Configure logging
22+
logging.basicConfig(level=logging.INFO)
23+
logger = logging.getLogger(__name__)
2324

2425
class AdvancedRAGChat:
2526
def __init__(
@@ -40,13 +41,17 @@ def __init__(
4041
self.query_prompt_template = open(current_dir / "prompts/query.txt").read()
4142
self.answer_prompt_template = open(current_dir / "prompts/answer.txt").read()
4243

44+
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6), before_sleep=before_sleep_log(logger, logging.WARNING))
45+
async def openai_chat_completion(self, *args, **kwargs) -> ChatCompletion:
46+
return await self.openai_chat_client.chat.completions.create(*args, **kwargs)
47+
4348
async def hybrid_search(self, messages, top, vector_search, text_search):
4449
# Generate an optimized keyword search query based on the chat history and the last question
4550
query_messages = copy.deepcopy(messages)
4651
query_messages.insert(0, {"role": "system", "content": self.query_prompt_template})
4752
query_response_token_limit = 500
4853

49-
query_chat_completion: ChatCompletion = await self.openai_chat_client.chat.completions.create(
54+
query_chat_completion: ChatCompletion = await self.openai_chat_completion(
5055
messages=query_messages,
5156
model=self.chat_deployment if self.chat_deployment else self.chat_model,
5257
temperature=0.0,
@@ -110,7 +115,7 @@ async def run(
110115
specify_package_messages.insert(0, {"role": "system", "content": self.specify_package_prompt_template})
111116
specify_package_token_limit = 300
112117

113-
specify_package_chat_completion: ChatCompletion = await self.openai_chat_client.chat.completions.create(
118+
specify_package_chat_completion: ChatCompletion = await self.openai_chat_completion(
114119
messages=specify_package_messages,
115120
model=self.chat_deployment if self.chat_deployment else self.chat_model,
116121
temperature=0.0,
@@ -155,9 +160,9 @@ async def run(
155160
# Build messages for the final chat completion
156161
messages.insert(0, {"role": "system", "content": self.answer_prompt_template})
157162
messages[-1]["content"].append({"type": "text", "text": "\n\nSources:\n" + content})
158-
response_token_limit = 1024
163+
response_token_limit = 4096
159164

160-
chat_completion_response = await self.openai_chat_client.chat.completions.create(
165+
chat_completion_response = await self.openai_chat_completion(
161166
model=self.chat_deployment if self.chat_deployment else self.chat_model,
162167
messages=messages,
163168
temperature=overrides.get("temperature", 0.3),

src/requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,7 @@ email_validator==2.1.1
2121
environs==11.0.0
2222
fastapi==0.111.0
2323
fastapi-cli==0.0.4
24-
-e git+https://github.com/azure-samples/rag-postgres-openai-python@1b189c6a227119d31a3947afc36d229cc0b2ac58#egg=fastapi_app&subdirectory=src
24+
-e git+https://github.com/chatrtham/rag-postgres-openai-python.git@7a91d2ab7d3814bb2ed6286a8b89255405309e94#egg=fastapi_app&subdirectory=src
2525
filelock==3.14.0
2626
frozenlist==1.4.1
2727
gitdb==4.0.11
@@ -76,6 +76,7 @@ smmap==5.0.1
7676
sniffio==1.3.1
7777
SQLAlchemy==2.0.30
7878
starlette==0.37.2
79+
tenacity==8.4.1
7980
tiktoken==0.7.0
8081
tqdm==4.66.4
8182
typer==0.12.3

0 commit comments

Comments
 (0)