 import copy
+import logging
 import pathlib
 from collections.abc import AsyncGenerator
-from typing import (
-    Any,
-)
+from typing import Any
 
 from openai import AsyncOpenAI
-from openai.types.chat import (
-    ChatCompletion,
-)
+from openai.types.chat import ChatCompletion
 from openai_messages_token_helper import get_token_limit
+from tenacity import before_sleep_log, retry, stop_after_attempt, wait_random_exponential
 
 from .api_models import ThoughtStep
 from .postgres_searcher import PostgresSearcher
     handle_specify_package_function_call,
 )
 
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
 
 class AdvancedRAGChat:
     def __init__(
@@ -40,13 +41,17 @@ def __init__(
         self.query_prompt_template = open(current_dir / "prompts/query.txt").read()
         self.answer_prompt_template = open(current_dir / "prompts/answer.txt").read()
 
+    @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6), before_sleep=before_sleep_log(logger, logging.WARNING))
+    async def openai_chat_completion(self, *args, **kwargs) -> ChatCompletion:
+        return await self.openai_chat_client.chat.completions.create(*args, **kwargs)
+
     async def hybrid_search(self, messages, top, vector_search, text_search):
         # Generate an optimized keyword search query based on the chat history and the last question
         query_messages = copy.deepcopy(messages)
         query_messages.insert(0, {"role": "system", "content": self.query_prompt_template})
         query_response_token_limit = 500
 
-        query_chat_completion: ChatCompletion = await self.openai_chat_client.chat.completions.create(
+        query_chat_completion: ChatCompletion = await self.openai_chat_completion(
             messages=query_messages,
             model=self.chat_deployment if self.chat_deployment else self.chat_model,
             temperature=0.0,
@@ -110,7 +115,7 @@ async def run(
         specify_package_messages.insert(0, {"role": "system", "content": self.specify_package_prompt_template})
         specify_package_token_limit = 300
 
-        specify_package_chat_completion: ChatCompletion = await self.openai_chat_client.chat.completions.create(
+        specify_package_chat_completion: ChatCompletion = await self.openai_chat_completion(
             messages=specify_package_messages,
             model=self.chat_deployment if self.chat_deployment else self.chat_model,
             temperature=0.0,
@@ -155,9 +160,9 @@ async def run(
         # Build messages for the final chat completion
         messages.insert(0, {"role": "system", "content": self.answer_prompt_template})
         messages[-1]["content"].append({"type": "text", "text": "\n\nSources:\n" + content})
-        response_token_limit = 1024
+        response_token_limit = 4096
 
-        chat_completion_response = await self.openai_chat_client.chat.completions.create(
+        chat_completion_response = await self.openai_chat_completion(
             model=self.chat_deployment if self.chat_deployment else self.chat_model,
             messages=messages,
             temperature=overrides.get("temperature", 0.3),
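
For context, here is a minimal standalone sketch of the retry pattern this diff introduces: a tenacity-decorated wrapper around `AsyncOpenAI.chat.completions.create` with full-jitter exponential backoff. The client setup, model name, prompt, and `main()` harness below are illustrative assumptions, not part of the commit; they assume `OPENAI_API_KEY` is set in the environment.

```python
import asyncio
import logging

from openai import AsyncOpenAI
from openai.types.chat import ChatCompletion
from tenacity import before_sleep_log, retry, stop_after_attempt, wait_random_exponential

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

client = AsyncOpenAI()  # reads OPENAI_API_KEY from the environment


# Wait a random 1-60s between attempts (exponential backoff with jitter),
# give up after 6 attempts, and log a WARNING before each retry sleep.
@retry(
    wait=wait_random_exponential(min=1, max=60),
    stop=stop_after_attempt(6),
    before_sleep=before_sleep_log(logger, logging.WARNING),
)
async def chat_completion_with_retry(*args, **kwargs) -> ChatCompletion:
    # tenacity detects the coroutine and retries it asynchronously
    return await client.chat.completions.create(*args, **kwargs)


async def main() -> None:
    completion = await chat_completion_with_retry(
        model="gpt-4o-mini",  # placeholder model name
        messages=[{"role": "user", "content": "Say hello."}],
    )
    print(completion.choices[0].message.content)


if __name__ == "__main__":
    asyncio.run(main())
```

Note one design choice carried over from the commit: a bare `@retry(...)` retries on *any* exception, including non-transient ones such as authentication errors. A stricter variant would limit retries to transient failures, e.g. `retry=retry_if_exception_type(openai.RateLimitError)`.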