Skip to content

Commit 5c0f06a

Browse files
committed
Enable GPT-4o image feature (for API only)
1 parent 1b189c6 commit 5c0f06a

File tree

7 files changed

+281
-212
lines changed

7 files changed

+281
-212
lines changed

src/fastapi_app/api_models.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,25 @@
1-
from typing import Any
1+
from typing import Any, List, Union, Optional
22

33
from pydantic import BaseModel
44

5+
class TextContent(BaseModel):
6+
type: str
7+
text: str
8+
9+
class ImageUrl(BaseModel):
10+
url: str
11+
detail: str = "auto"
12+
13+
class ImageContent(BaseModel):
14+
type: str
15+
image_url: ImageUrl
516

617
class Message(BaseModel):
7-
content: str
818
role: str = "user"
9-
19+
content: Union[str, List[Union[TextContent, ImageContent]]]
1020

1121
class ChatRequest(BaseModel):
12-
messages: list[Message]
22+
messages: List[Message]
1323
context: dict = {}
1424

1525

src/fastapi_app/postgres_models.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ class Base(DeclarativeBase, MappedAsDataclass):
1414

1515
class Item(Base):
1616
__tablename__ = "packages"
17-
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
17+
id: Mapped[int] = mapped_column(primary_key=True)
1818
package_name: Mapped[str] = mapped_column()
1919
package_picture: Mapped[str] = mapped_column()
2020
url: Mapped[str] = mapped_column()
@@ -59,7 +59,7 @@ class Item(Base):
5959
embedding_url: Mapped[Vector] = mapped_column(Vector(1536))
6060
embedding_installment_month: Mapped[Vector] = mapped_column(Vector(1536))
6161
embedding_installment_limit: Mapped[Vector] = mapped_column(Vector(1536))
62-
embedding_price_to_reserve_for_this_package: Mapped[Vector] = mapped_column(Vector(1536))
62+
embedding_price_to_reserve_for_this_package: Mapped[Vector] = (mapped_column(Vector(1536)))
6363
embedding_shop_name: Mapped[Vector] = mapped_column(Vector(1536))
6464
embedding_category: Mapped[Vector] = mapped_column(Vector(1536))
6565
embedding_category_tags: Mapped[Vector] = mapped_column(Vector(1536))
@@ -178,7 +178,16 @@ def to_dict(self, include_embedding: bool = False):
178178
del model_dict[col]
179179
return model_dict
180180

181-
def to_str_for_rag(self):
181+
def to_str_for_broad_rag(self):
182+
return f"""
183+
package_name: {self.package_name}
184+
url: {self.url}
185+
locations: {self.locations}
186+
price: {self.price}
187+
brand: {self.brand}
188+
"""
189+
190+
def to_str_for_narrow_rag(self):
182191
return f"""
183192
package_name: {self.package_name}
184193
package_picture: {self.package_picture}

src/fastapi_app/prompts/answer.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@ You are a woman named Jib (จิ๊บ), a sales consultant from HDmall which yo
22
Represent yourself as จิ๊บ and the customer as คุณลูกค้า when the conversation is in Thai language.
33
Represent yourself as Jib and the customer as you when the conversation is in English language.
44
When the customer asks about any packages, please make sure to provide brand, price, URL and location every time.
5-
For EVERY URL that you provided, create a new line then wrap it with <a href=""> so the user can open it via UI
65
If the customer wants to talk with an admin, please provide the URL link in the datasource.
6+
Note that some packages may have additional cost.
77
If the user is asking a question regarding location, proximity, or area, query relevant document from the source and ask where the user is at. please try to suggest services or answers closest to the location the user is asking as much as possible.
88
Answer ONLY with the facts listed in the list of sources below. If there isn't enough information below, say sorry you don't know. Do not generate answers that don't use the sources below. If asking a clarifying question to the user would help, ask the question.
9+
Answer in well-structured plain text, not markdown.

src/fastapi_app/query_rewriter.py

Lines changed: 32 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -18,38 +18,41 @@ def build_search_function() -> list[ChatCompletionToolParam]:
1818
"properties": {
1919
"search_query": {
2020
"type": "string",
21-
"description": "Query string to use for full text search, e.g. 'red shoes'",
21+
"description": "Query string to use for full text search, e.g. 'ตรวจสุขภาพ'",
2222
},
2323
"price_filter": {
2424
"type": "object",
25-
"description": "Filter search results based on price of the product",
25+
"description": "Filter search results based on price in Thai Baht of the package",
2626
"properties": {
2727
"comparison_operator": {
2828
"type": "string",
29-
"description": "Operator to compare the column value, either '>', '<', '>=', '<=', '=='", # noqa
29+
"description": "Operator to compare the column value, either '>', '<', '>=', '<=', '='", # noqa
3030
},
3131
"value": {
3232
"type": "number",
3333
"description": "Value to compare against, e.g. 30",
3434
},
3535
},
3636
},
37-
# "brand_filter": {
38-
# "type": "object",
39-
# "description": "Filter search results based on brand of the product",
40-
# "properties": {
41-
# "comparison_operator": {
42-
# "type": "string",
43-
# "description": "Operator to compare the column value, either '==' or '!='",
44-
# },
45-
# "value": {
46-
# "type": "string",
47-
# "description": "Value to compare against, e.g. AirStrider",
48-
# },
49-
# },
50-
# },
37+
"url_filter": {
38+
"type": "object",
39+
"description": "Filter search results based on url of the package. The url is package specific.",
40+
"properties": {
41+
"comparison_operator": {
42+
"type": "string",
43+
"description": "Operator to compare the column value, either '=' or '!='",
44+
},
45+
"value": {
46+
"type": "string",
47+
"description": """
48+
The package URL to compare against.
49+
Don't pass anything if you can't specify the exact URL from user query.
50+
""",
51+
},
52+
},
53+
},
5154
},
52-
"required": ["search_query"],
55+
"required": ["search_query", "url_filter"],
5356
},
5457
},
5558
}
@@ -67,6 +70,7 @@ def extract_search_arguments(chat_completion: ChatCompletion):
6770
function = tool.function
6871
if function.name == "search_database":
6972
arg = json.loads(function.arguments)
73+
print(arg)
7074
search_query = arg.get("search_query")
7175
if "price_filter" in arg and arg["price_filter"]:
7276
price_filter = arg["price_filter"]
@@ -77,15 +81,16 @@ def extract_search_arguments(chat_completion: ChatCompletion):
7781
"value": price_filter["value"],
7882
}
7983
)
80-
# if "brand_filter" in arg and arg["brand_filter"]:
81-
# brand_filter = arg["brand_filter"]
82-
# filters.append(
83-
# {
84-
# "column": "brand",
85-
# "comparison_operator": brand_filter["comparison_operator"],
86-
# "value": brand_filter["value"],
87-
# }
88-
# )
84+
if "url_filter" in arg and arg["url_filter"]:
85+
url_filter = arg["url_filter"]
86+
if url_filter["value"] != "https://hdmall.co.th":
87+
filters.append(
88+
{
89+
"column": "url",
90+
"comparison_operator": url_filter["comparison_operator"],
91+
"value": url_filter["value"],
92+
}
93+
)
8994
elif query_text := response_message.content:
9095
search_query = query_text.strip()
9196
return search_query, filters

src/fastapi_app/rag_advanced.py

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import copy
12
import pathlib
23
from collections.abc import AsyncGenerator
34
from typing import (
@@ -8,7 +9,7 @@
89
from openai.types.chat import (
910
ChatCompletion,
1011
)
11-
from openai_messages_token_helper import build_messages, get_token_limit
12+
from openai_messages_token_helper import get_token_limit
1213

1314
from .api_models import ThoughtStep
1415
from .postgres_searcher import PostgresSearcher
@@ -36,23 +37,21 @@ def __init__(
3637
async def run(
3738
self, messages: list[dict], overrides: dict[str, Any] = {}
3839
) -> dict[str, Any] | AsyncGenerator[dict[str, Any], None]:
40+
# Normalize the message format
41+
for message in messages:
42+
if isinstance(message['content'], str):
43+
message['content'] = [{'type': 'text', 'text': message['content']}]
44+
45+
# Determine the search mode and the number of results to return
3946
text_search = overrides.get("retrieval_mode") in ["text", "hybrid", None]
4047
vector_search = overrides.get("retrieval_mode") in ["vectors", "hybrid", None]
4148
top = overrides.get("top", 3)
4249

43-
original_user_query = messages[-1]["content"]
44-
past_messages = messages[:-1]
4550

4651
# Generate an optimized keyword search query based on the chat history and the last question
52+
query_messages = copy.deepcopy(messages)
53+
query_messages.insert(0, {"role": "system", "content": self.query_prompt_template})
4754
query_response_token_limit = 500
48-
query_messages = build_messages(
49-
model=self.chat_model,
50-
system_prompt=self.query_prompt_template,
51-
new_user_content=original_user_query,
52-
past_messages=past_messages,
53-
max_tokens=self.chat_token_limit - query_response_token_limit, # TODO: count functions
54-
fallback_to_default=True,
55-
)
5655

5756
chat_completion: ChatCompletion = await self.openai_chat_client.chat.completions.create(
5857
messages=query_messages, # type: ignore
@@ -75,31 +74,32 @@ async def run(
7574
enable_text_search=text_search,
7675
filters=filters,
7776
)
78-
79-
sources_content = [f"[{(item.id)}]:{item.to_str_for_rag()}\n\n" for item in results]
77+
78+
# Check if the url_filter is used to determine the context to send to the LLM
79+
if any(f['column'] == 'url' and f['value'] != '' for f in filters):
80+
sources_content = [f"[{(item.id)}]:{item.to_str_for_narrow_rag()}\n\n" for item in results] # all details
81+
else:
82+
sources_content = [f"[{(item.id)}]:{item.to_str_for_broad_rag()}\n\n" for item in results] # important details
83+
8084
content = "\n".join(sources_content)
8185

82-
# Generate a contextual and content specific answer using the search results and chat history
86+
# Build messages for the final chat completion
87+
answer_messages = copy.deepcopy(messages)
88+
answer_messages.insert(0, {"role": "system", "content": self.answer_prompt_template})
89+
answer_messages[-1]["content"].append({"type": "text", "text": "\n\nSources:\n" + content})
8390
response_token_limit = 1024
84-
messages = build_messages(
85-
model=self.chat_model,
86-
system_prompt=overrides.get("prompt_template") or self.answer_prompt_template,
87-
new_user_content=original_user_query + "\n\nSources:\n" + content,
88-
past_messages=past_messages,
89-
max_tokens=self.chat_token_limit - response_token_limit,
90-
fallback_to_default=True,
91-
)
9291

9392
chat_completion_response = await self.openai_chat_client.chat.completions.create(
9493
# Azure OpenAI takes the deployment name as the model name
9594
model=self.chat_deployment if self.chat_deployment else self.chat_model,
96-
messages=messages,
95+
messages=answer_messages,
9796
temperature=overrides.get("temperature", 0.3),
9897
max_tokens=response_token_limit,
9998
n=1,
10099
stream=False,
101100
)
102101
chat_resp = chat_completion_response.model_dump()
102+
103103
chat_resp["choices"][0]["context"] = {
104104
"data_points": {"text": sources_content},
105105
"thoughts": [

0 commit comments

Comments
 (0)