-import pathlib
 from collections.abc import AsyncGenerator
-from typing import (
-    Any,
-)
+from typing import Any

-from openai import AsyncAzureOpenAI, AsyncOpenAI
-from openai.types.chat import ChatCompletion, ChatCompletionMessageParam
+from openai import AsyncAzureOpenAI, AsyncOpenAI, AsyncStream
+from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageParam
 from openai_messages_token_helper import build_messages, get_token_limit

-from .api_models import Message, RAGContext, RetrievalResponse, ThoughtStep
-from .postgres_searcher import PostgresSearcher
-from .query_rewriter import build_search_function, extract_search_arguments
+from fastapi_app.api_models import Message, RAGContext, RetrievalResponse, ThoughtStep
+from fastapi_app.postgres_searcher import PostgresSearcher
+from fastapi_app.query_rewriter import build_search_function, extract_search_arguments
+from fastapi_app.rag_simple import RAGChatBase


-class AdvancedRAGChat:
+class AdvancedRAGChat(RAGChatBase):
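+    # Note: the query/answer prompt templates and get_params() used below are now
+    # provided by the RAGChatBase parent; the per-class template loading is dropped
+    # from __init__ below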
     def __init__(
         self,
         *,
@@ -27,29 +25,21 @@ def __init__(
         self.chat_model = chat_model
         self.chat_deployment = chat_deployment
         self.chat_token_limit = get_token_limit(chat_model, default_to_minimum=True)
-        current_dir = pathlib.Path(__file__).parent
-        self.query_prompt_template = open(current_dir / "prompts/query.txt").read()
-        self.answer_prompt_template = open(current_dir / "prompts/answer.txt").read()

     async def run(
-        self, messages: list[ChatCompletionMessageParam], overrides: dict[str, Any] = {}
-    ) -> RetrievalResponse | AsyncGenerator[dict[str, Any], None]:
-        text_search = overrides.get("retrieval_mode") in ["text", "hybrid", None]
-        vector_search = overrides.get("retrieval_mode") in ["vectors", "hybrid", None]
-        top = overrides.get("top", 3)
-
-        original_user_query = messages[-1]["content"]
-        if not isinstance(original_user_query, str):
-            raise ValueError("The most recent message content must be a string.")
-        past_messages = messages[:-1]
+        self,
+        messages: list[ChatCompletionMessageParam],
+        overrides: dict[str, Any] = {},
+    ) -> RetrievalResponse:
+        chat_params = self.get_params(messages, overrides)
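+        # chat_params (from RAGChatBase.get_params) bundles the parsed request:
+        # original_user_query, past_messages, top, and the text/vector search flags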

         # Generate an optimized keyword search query based on the chat history and the last question
         query_response_token_limit = 500
         query_messages: list[ChatCompletionMessageParam] = build_messages(
             model=self.chat_model,
             system_prompt=self.query_prompt_template,
-            new_user_content=original_user_query,
-            past_messages=past_messages,
+            new_user_content=chat_params.original_user_query,
+            past_messages=chat_params.past_messages,
             max_tokens=self.chat_token_limit - query_response_token_limit,  # TODO: count functions
             fallback_to_default=True,
         )
@@ -65,14 +55,14 @@ async def run(
             tool_choice="auto",
         )

-        query_text, filters = extract_search_arguments(original_user_query, chat_completion)
+        query_text, filters = extract_search_arguments(chat_params.original_user_query, chat_completion)

         # Retrieve relevant items from the database with the GPT optimized query
         results = await self.searcher.search_and_embed(
             query_text,
-            top=top,
-            enable_vector_search=vector_search,
-            enable_text_search=text_search,
+            top=chat_params.top,
+            enable_vector_search=chat_params.enable_vector_search,
+            enable_text_search=chat_params.enable_text_search,
             filters=filters,
         )

@@ -84,8 +74,8 @@ async def run(
         contextual_messages: list[ChatCompletionMessageParam] = build_messages(
             model=self.chat_model,
             system_prompt=overrides.get("prompt_template") or self.answer_prompt_template,
-            new_user_content=original_user_query + "\n\nSources:\n" + content,
-            past_messages=past_messages,
+            new_user_content=chat_params.original_user_query + "\n\nSources:\n" + content,
+            past_messages=chat_params.past_messages,
             max_tokens=self.chat_token_limit - response_token_limit,
             fallback_to_default=True,
         )
@@ -99,6 +89,7 @@ async def run(
             n=1,
             stream=False,
         )
+
         first_choice_message = chat_completion_response.choices[0].message

         return RetrievalResponse(
@@ -119,9 +110,9 @@ async def run(
                         title="Search using generated search arguments",
                         description=query_text,
                         props={
-                            "top": top,
-                            "vector_search": vector_search,
-                            "text_search": text_search,
+                            "top": chat_params.top,
+                            "vector_search": chat_params.enable_vector_search,
+                            "text_search": chat_params.enable_text_search,
                             "filters": filters,
                         },
                     ),
@@ -141,3 +132,114 @@ async def run(
                 ],
             ),
         )
+
+    async def run_stream(
+        self,
+        messages: list[ChatCompletionMessageParam],
+        overrides: dict[str, Any] = {},
+    ) -> AsyncGenerator[RetrievalResponse | Message, None]:
+        chat_params = self.get_params(messages, overrides)
+
+        # Generate an optimized keyword search query based on the chat history and the last question
+        query_response_token_limit = 500
+        query_messages: list[ChatCompletionMessageParam] = build_messages(
+            model=self.chat_model,
+            system_prompt=self.query_prompt_template,
+            new_user_content=chat_params.original_user_query,
+            past_messages=chat_params.past_messages,
+            max_tokens=self.chat_token_limit - query_response_token_limit,  # TODO: count functions
+            fallback_to_default=True,
+        )
+
+        chat_completion: ChatCompletion = await self.openai_chat_client.chat.completions.create(
+            messages=query_messages,
+            # Azure OpenAI takes the deployment name as the model name
+            model=self.chat_deployment if self.chat_deployment else self.chat_model,
+            temperature=0.0,  # Minimize creativity for search query generation
+            max_tokens=query_response_token_limit,  # Setting too low risks malformed JSON, too high risks performance
+            n=1,
+            tools=build_search_function(),
+            tool_choice="auto",
+        )
+
+        query_text, filters = extract_search_arguments(chat_params.original_user_query, chat_completion)
+
+        # Retrieve relevant items from the database with the GPT optimized query
+        results = await self.searcher.search_and_embed(
+            query_text,
+            top=chat_params.top,
+            enable_vector_search=chat_params.enable_vector_search,
+            enable_text_search=chat_params.enable_text_search,
+            filters=filters,
+        )
+
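+        # Each result is flattened to a "[id]:text" line; presumably this lets the
+        # answer prompt ask the model to cite sources by id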
+        sources_content = [f"[{(item.id)}]:{item.to_str_for_rag()}\n\n" for item in results]
+        content = "\n".join(sources_content)
+
+        # Generate a contextual and content specific answer using the search results and chat history
+        response_token_limit = 1024
+        contextual_messages: list[ChatCompletionMessageParam] = build_messages(
+            model=self.chat_model,
+            system_prompt=overrides.get("prompt_template") or self.answer_prompt_template,
+            new_user_content=chat_params.original_user_query + "\n\nSources:\n" + content,
+            past_messages=chat_params.past_messages,
+            max_tokens=self.chat_token_limit - response_token_limit,
+            fallback_to_default=True,
+        )
+
+        chat_completion_async_stream: AsyncStream[
+            ChatCompletionChunk
+        ] = await self.openai_chat_client.chat.completions.create(
+            # Azure OpenAI takes the deployment name as the model name
+            model=self.chat_deployment if self.chat_deployment else self.chat_model,
+            messages=contextual_messages,
+            temperature=overrides.get("temperature", 0.3),
+            max_tokens=response_token_limit,
+            n=1,
+            stream=True,
+        )
+
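+        # Emit one context-only RetrievalResponse (empty assistant message plus all
+        # thought steps) before streaming the answer tokens below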
+        yield RetrievalResponse(
+            message=Message(content="", role="assistant"),
+            context=RAGContext(
+                data_points={item.id: item.to_dict() for item in results},
+                thoughts=[
+                    ThoughtStep(
+                        title="Prompt to generate search arguments",
+                        description=[str(message) for message in query_messages],
+                        props=(
+                            {"model": self.chat_model, "deployment": self.chat_deployment}
+                            if self.chat_deployment
+                            else {"model": self.chat_model}
+                        ),
+                    ),
+                    ThoughtStep(
+                        title="Search using generated search arguments",
+                        description=query_text,
+                        props={
+                            "top": chat_params.top,
+                            "vector_search": chat_params.enable_vector_search,
+                            "text_search": chat_params.enable_text_search,
+                            "filters": filters,
+                        },
+                    ),
+                    ThoughtStep(
+                        title="Search results",
+                        description=[result.to_dict() for result in results],
+                    ),
+                    ThoughtStep(
+                        title="Prompt to generate answer",
+                        description=[str(message) for message in contextual_messages],
+                        props=(
+                            {"model": self.chat_model, "deployment": self.chat_deployment}
+                            if self.chat_deployment
+                            else {"model": self.chat_model}
+                        ),
+                    ),
+                ],
+            ),
+        )
+
+        async for response_chunk in chat_completion_async_stream:
+            # Skip chunks with no content (e.g. the final chunk, whose delta.content
+            # is None) instead of yielding the string "None"
+            if response_chunk.choices and response_chunk.choices[0].delta.content:
+                yield Message(content=response_chunk.choices[0].delta.content, role="assistant")
+        return
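
Note on the new streaming contract: run_stream yields a single RetrievalResponse first (an empty assistant message carrying the RAG context and thought steps), then the answer as a sequence of Message deltas. A minimal consumption sketch, assuming the API models are Pydantic v2 and an NDJSON-framed FastAPI route; the route path, request shape, and helper name are hypothetical, not part of this commit:

from collections.abc import AsyncGenerator

from fastapi.responses import StreamingResponse

async def format_as_ndjson(events: AsyncGenerator) -> AsyncGenerator[str, None]:
    # First event: RetrievalResponse with the context; later events: Message deltas
    async for event in events:
        yield event.model_dump_json() + "\n"  # assumes Pydantic v2 models

# Hypothetical route wiring:
# @router.post("/chat/stream")
# async def chat_stream_handler(request: ChatRequest):
#     rag_flow = AdvancedRAGChat(searcher=..., openai_chat_client=..., chat_model=..., chat_deployment=...)
#     return StreamingResponse(
#         format_as_ndjson(rag_flow.run_stream(request.messages, request.overrides)),
#         media_type="application/x-ndjson",
#     )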