

 class AdvancedRAGChat:
-
     def __init__(
         self,
         *,
@@ -26,7 +25,8 @@ def __init__(
         chat_model: str,
         chat_deployment: str | None,  # Not needed for non-Azure OpenAI
         openai_embed_client: AsyncOpenAI,
-        embed_deployment: str | None,  # Not needed for non-Azure OpenAI or for retrieval_mode="text"
+        embed_deployment: str
+        | None,  # Not needed for non-Azure OpenAI or for retrieval_mode="text"
         embed_model: str,
         embed_dimensions: int,
     ):
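For context, the constructor pairs a chat client/model with a separate embedding client/model; the deployment names only matter on Azure OpenAI, where the deployment is passed in place of the model name. A minimal wiring sketch, assuming the keyword-only signature above plus a searcher and openai_chat_client that this hunk does not show (model names are illustrative):

from openai import AsyncOpenAI

# Hypothetical wiring; `searcher` and `openai_chat_client` are assumed from context.
chat = AdvancedRAGChat(
    searcher=searcher,
    openai_chat_client=AsyncOpenAI(api_key="..."),
    chat_model="gpt-4o-mini",
    chat_deployment=None,  # only set for Azure OpenAI
    openai_embed_client=AsyncOpenAI(api_key="..."),
    embed_deployment=None,  # only set for Azure OpenAI; unused for retrieval_mode="text"
    embed_model="text-embedding-3-small",
    embed_dimensions=256,
)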
@@ -46,7 +46,6 @@ def __init__(
     async def run(
         self, messages: list[dict], overrides: dict[str, Any] = {}
     ) -> dict[str, Any] | AsyncGenerator[dict[str, Any], None]:
-
         text_search = overrides.get("retrieval_mode") in ["text", "hybrid", None]
         vector_search = overrides.get("retrieval_mode") in ["vectors", "hybrid", None]
         top = overrides.get("top", 3)
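Note that a missing retrieval_mode (None) satisfies both membership checks, so hybrid search is the default. Spelled out as a mapping:

# retrieval_mode -> (text_search, vector_search)
#   "text"    -> (True,  False)
#   "vectors" -> (False, True)
#   "hybrid"  -> (True,  True)
#   None      -> (True,  True)   # hybrid by default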
@@ -61,7 +60,8 @@ async def run(
             system_prompt=self.query_prompt_template,
             new_user_content=original_user_query,
             past_messages=past_messages,
-            max_tokens=self.chat_token_limit - query_response_token_limit,  # TODO: count functions
+            max_tokens=self.chat_token_limit
+            - query_response_token_limit,  # TODO: count functions
             fallback_to_default=True,
         )

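The max_tokens subtraction reserves room in the context window for the model's reply, so build_messages can trim past_messages until the prompt fits. A rough sketch of that budget check, assuming a tiktoken-based counter (the actual helper may count differently, and per the TODO does not yet account for the function definitions):

import tiktoken

def fits_budget(messages: list[dict], model: str, chat_token_limit: int, response_limit: int) -> bool:
    # Rough prompt-size check: total prompt tokens must leave response_limit tokens free.
    encoding = tiktoken.encoding_for_model(model)
    prompt_tokens = sum(len(encoding.encode(m["content"])) for m in messages)
    return prompt_tokens <= chat_token_limit - response_limit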
@@ -70,7 +70,7 @@ async def run(
             # Azure OpenAI takes the deployment name as the model name
             model=self.chat_deployment if self.chat_deployment else self.chat_model,
             temperature=0.0,  # Minimize creativity for search query generation
-            max_tokens=query_response_token_limit,  # Setting too low risks malformed JSON, setting too high may affect performance
+            max_tokens=query_response_token_limit,  # Setting too low risks malformed JSON, too high risks performance
             n=1,
             tools=build_search_function(),
             tool_choice="auto",
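With tool_choice="auto", the model decides whether to emit a structured search call instead of plain text. build_search_function is not shown in this diff, but an OpenAI tool definition of this kind generally has the following shape (the function name and parameters here are illustrative assumptions, not the repo's actual schema):

def build_search_function() -> list[dict]:
    # Illustrative tool schema; "search_database" and "search_query" are assumed names.
    return [
        {
            "type": "function",
            "function": {
                "name": "search_database",
                "description": "Search the database for rows matching the user's query",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "search_query": {
                            "type": "string",
                            "description": "Query string for full-text search",
                        }
                    },
                    "required": ["search_query"],
                },
            },
        }
    ]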
@@ -93,14 +93,17 @@ async def run(

         results = await self.searcher.search(query_text, vector, top, filters)

-        sources_content = [f"[{(item.id)}]:{item.to_str_for_rag()}\n\n" for item in results]
+        sources_content = [
+            f"[{(item.id)}]:{item.to_str_for_rag()}\n\n" for item in results
+        ]
         content = "\n".join(sources_content)

         # Generate a contextual and content specific answer using the search results and chat history
         response_token_limit = 1024
         messages = build_messages(
             model=self.chat_model,
-            system_prompt=overrides.get("prompt_template") or self.answer_prompt_template,
+            system_prompt=overrides.get("prompt_template")
+            or self.answer_prompt_template,
             new_user_content=original_user_query + "\n\nSources:\n" + content,
             past_messages=past_messages,
             max_tokens=self.chat_token_limit - response_token_limit,
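Each search result is serialized as "[id]:row text", so the answer model can cite a source by its bracketed id, and the system prompt falls back to answer_prompt_template whenever the prompt_template override is missing or empty (the "or" catches both None and ""). With two hypothetical results, the content appended to the user message would look like:

[1]:<row 1 as rendered by to_str_for_rag()>

[2]:<row 2 as rendered by to_str_for_rag()>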