Skip to content

Commit 7a91d2a

Browse files
committed
Add use_or parameter to build_filter_clause method + Fix run method in AdvancedRAGChat
1 parent d2a59ff commit 7a91d2a

File tree

2 files changed

+84
-76
lines changed

2 files changed

+84
-76
lines changed

src/fastapi_app/postgres_searcher.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -22,15 +22,15 @@ def __init__(
2222
self.embed_deployment = embed_deployment
2323
self.embed_dimensions = embed_dimensions
2424

25-
def build_filter_clause(self, filters) -> tuple[str, str]:
25+
def build_filter_clause(self, filters, use_or=False) -> tuple[str, str]:
2626
if filters is None:
2727
return "", ""
2828
filter_clauses = []
2929
for filter in filters:
3030
if isinstance(filter["value"], str):
3131
filter["value"] = f"'{filter['value']}'"
3232
filter_clauses.append(f"{filter['column']} {filter['comparison_operator']} {filter['value']}")
33-
filter_clause = " AND ".join(filter_clauses)
33+
filter_clause = f" {'OR' if use_or else 'AND'} ".join(filter_clauses)
3434
if len(filter_clause) > 0:
3535
return f"WHERE {filter_clause}", f"AND {filter_clause}"
3636
return "", ""
@@ -290,11 +290,11 @@ async def simple_sql_search(
290290
"""
291291
Search items by simple SQL query with filters.
292292
"""
293-
filter_clause_where, _ = self.build_filter_clause(filters)
293+
filter_clause_where, _ = self.build_filter_clause(filters, use_or=True)
294294
sql = f"""
295295
SELECT id FROM packages
296296
{filter_clause_where}
297-
LIMIT 1
297+
LIMIT 2
298298
"""
299299

300300
async with self.async_session_maker() as session:

src/fastapi_app/rag_advanced.py

Lines changed: 80 additions & 72 deletions
Original file line number | Diff line number | Diff line change
@@ -15,9 +15,9 @@
1515
from .postgres_searcher import PostgresSearcher
1616
from .query_rewriter import (
1717
build_hybrid_search_function,
18-
extract_search_arguments,
1918
build_specify_package_function,
20-
handle_specify_package_function_call
19+
extract_search_arguments,
20+
handle_specify_package_function_call,
2121
)
2222

2323

@@ -40,6 +40,58 @@ def __init__(
4040
self.query_prompt_template = open(current_dir / "prompts/query.txt").read()
4141
self.answer_prompt_template = open(current_dir / "prompts/answer.txt").read()
4242

43+
async def hybrid_search(self, messages, top, vector_search, text_search):
44+
# Generate an optimized keyword search query based on the chat history and the last question
45+
query_messages = copy.deepcopy(messages)
46+
query_messages.insert(0, {"role": "system", "content": self.query_prompt_template})
47+
query_response_token_limit = 500
48+
49+
query_chat_completion: ChatCompletion = await self.openai_chat_client.chat.completions.create(
50+
messages=query_messages,
51+
model=self.chat_deployment if self.chat_deployment else self.chat_model,
52+
temperature=0.0,
53+
max_tokens=query_response_token_limit,
54+
n=1,
55+
tools=build_hybrid_search_function(),
56+
tool_choice="auto",
57+
)
58+
59+
query_text, filters = extract_search_arguments(query_chat_completion)
60+
61+
# Retrieve relevant items from the database with the GPT optimized query
62+
results = await self.searcher.search_and_embed(
63+
query_text,
64+
top=top,
65+
enable_vector_search=vector_search,
66+
enable_text_search=text_search,
67+
filters=filters,
68+
)
69+
70+
sources_content = [f"[{(item.id)}]:{item.to_str_for_broad_rag()}\n\n" for item in results]
71+
72+
thought_steps = [
73+
ThoughtStep(
74+
title="Prompt to generate search arguments",
75+
description=[str(message) for message in query_messages],
76+
props={"model": self.chat_model, "deployment": self.chat_deployment} if self.chat_deployment else {"model": self.chat_model}
77+
),
78+
ThoughtStep(
79+
title="Generated search arguments",
80+
description=query_text,
81+
props={"filters": filters}
82+
),
83+
ThoughtStep(
84+
title="Hybrid Search results",
85+
description=[result.to_dict() for result in results],
86+
props={
87+
"top": top,
88+
"vector_search": vector_search,
89+
"text_search": text_search
90+
}
91+
)
92+
]
93+
return sources_content, thought_steps
94+
4395
async def run(
4496
self, messages: list[dict], overrides: dict[str, Any] = {}
4597
) -> dict[str, Any] | AsyncGenerator[dict[str, Any], None]:
@@ -69,78 +121,34 @@ async def run(
69121

70122
specify_package_filters = handle_specify_package_function_call(specify_package_chat_completion)
71123

72-
if specify_package_filters:
73-
# Pass specify_package_filters to simple SQL search function
124+
if specify_package_filters: # Simple SQL search
74125
results = await self.searcher.simple_sql_search(filters=specify_package_filters)
75-
sources_content = [f"[{(item.id)}]:{item.to_str_for_narrow_rag()}\n\n" for item in results]
76-
77-
thought_steps = [
78-
ThoughtStep(
79-
title="Prompt to specify package",
80-
description=[str(message) for message in specify_package_messages],
81-
props={"model": self.chat_model, "deployment": self.chat_deployment} if self.chat_deployment else {"model": self.chat_model}
82-
),
83-
ThoughtStep(
84-
title="Specified package filters",
85-
description=specify_package_filters,
86-
props={}
87-
),
88-
ThoughtStep(
89-
title="SQL search results",
90-
description=[result.to_dict() for result in results],
91-
props={}
92-
)
93-
]
94-
else:
95-
# Generate an optimized keyword search query based on the chat history and the last question
96-
query_messages = copy.deepcopy(messages)
97-
query_messages.insert(0, {"role": "system", "content": self.query_prompt_template})
98-
query_response_token_limit = 500
99-
100-
query_chat_completion: ChatCompletion = await self.openai_chat_client.chat.completions.create(
101-
messages=query_messages,
102-
model=self.chat_deployment if self.chat_deployment else self.chat_model,
103-
temperature=0.0,
104-
max_tokens=query_response_token_limit,
105-
n=1,
106-
tools=build_hybrid_search_function(),
107-
tool_choice="auto",
108-
)
109-
110-
query_text, filters = extract_search_arguments(query_chat_completion)
111-
112-
# Retrieve relevant items from the database with the GPT optimized query
113-
results = await self.searcher.search_and_embed(
114-
query_text,
115-
top=top,
116-
enable_vector_search=vector_search,
117-
enable_text_search=text_search,
118-
filters=filters,
119-
)
120126

121-
sources_content = [f"[{(item.id)}]:{item.to_str_for_broad_rag()}\n\n" for item in results]
122-
123-
thought_steps = [
124-
ThoughtStep(
125-
title="Prompt to generate search arguments",
126-
description=[str(message) for message in query_messages],
127-
props={"model": self.chat_model, "deployment": self.chat_deployment} if self.chat_deployment else {"model": self.chat_model}
128-
),
129-
ThoughtStep(
130-
title="Generated search arguments",
131-
description=query_text,
132-
props={"filters": filters}
133-
),
134-
ThoughtStep(
135-
title="Hybrid Search results",
136-
description=[result.to_dict() for result in results],
137-
props={
138-
"top": top,
139-
"vector_search": vector_search,
140-
"text_search": text_search
141-
}
142-
)
143-
]
127+
if results:
128+
sources_content = [f"[{(item.id)}]:{item.to_str_for_narrow_rag()}\n\n" for item in results]
129+
130+
thought_steps = [
131+
ThoughtStep(
132+
title="Prompt to specify package",
133+
description=[str(message) for message in specify_package_messages],
134+
props={"model": self.chat_model, "deployment": self.chat_deployment} if self.chat_deployment else {"model": self.chat_model}
135+
),
136+
ThoughtStep(
137+
title="Specified package filters",
138+
description=specify_package_filters,
139+
props={}
140+
),
141+
ThoughtStep(
142+
title="SQL search results",
143+
description=[result.to_dict() for result in results],
144+
props={}
145+
)
146+
]
147+
else:
148+
# No results found with SQL search, fall back to the hybrid search
149+
sources_content, thought_steps = await self.hybrid_search(messages, top, vector_search, text_search)
150+
else: # Hybrid search
151+
sources_content, thought_steps = await self.hybrid_search(messages, top, vector_search, text_search)
144152

145153
content = "\n".join(sources_content)
146154

0 commit comments

Comments (0)