Skip to content

Commit 293216f

Browse files
committed
Update embedding field names in update_embeddings.py
1 parent 49dd055 commit 293216f

File tree

2 files changed

+226
-6
lines changed

2 files changed

+226
-6
lines changed
Lines changed: 221 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,221 @@
1+
import argparse
2+
import asyncio
3+
import logging
4+
import os
5+
import time
6+
7+
import numpy as np
8+
import pandas as pd
9+
10+
from dotenv import load_dotenv
11+
from sqlalchemy import select, text, delete
12+
from sqlalchemy.ext.asyncio import async_sessionmaker
13+
from tqdm import tqdm
14+
from azure.identity.aio import DefaultAzureCredential
15+
16+
from fastapi_app.embeddings import compute_text_embedding
17+
from fastapi_app.openai_clients import create_openai_embed_client
18+
from fastapi_app.postgres_engine import (
19+
create_postgres_engine_from_args,
20+
create_postgres_engine_from_env,
21+
)
22+
from fastapi_app.postgres_models import Item
23+
24+
load_dotenv()

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("ragapp")

# Item columns that each get a companion ``embedding_<field>`` vector column.
# NOTE(review): each name here is expected to have a matching
# ``to_str_for_embedding_<field>`` method on the Item model — confirm against
# fastapi_app.postgres_models before adding/removing entries.
EMBEDDING_FIELDS = [
    'package_name', 'package_picture', 'url', 'installment_month', 'installment_limit',
    'price_to_reserve_for_this_package', 'shop_name', 'category', 'category_tags',
    'preview_1_10', 'selling_point', 'meta_keywords', 'brand', 'min_max_age',
    'locations', 'meta_description', 'price_details', 'package_details',
    'important_info', 'payment_booking_info', 'general_info', 'early_signs_for_diagnosis',
    'how_to_diagnose', 'hdcare_summary', 'common_question', 'know_this_disease',
    'courses_of_action', 'signals_to_proceed_surgery', 'get_to_know_this_surgery',
    'comparisons', 'getting_ready', 'recovery', 'side_effects', 'review_4_5_stars',
    'brand_option_in_thai_name', 'faq'
]
41+
42+
def get_to_str_method(item, field):
    """Return *item*'s ``to_str_for_embedding_<field>`` bound method, or None.

    The Item model exposes one text-serialisation method per embeddable
    field; None means the model defines no such method for *field*.
    """
    return getattr(item, f"to_str_for_embedding_{field}", None)
45+
46+
def convert_to_int(value):
    """Best-effort conversion of *value* to int; None when it cannot be coerced."""
    try:
        converted = int(value)
    except (ValueError, TypeError):
        return None
    return converted
51+
52+
def convert_to_float(value):
    """Best-effort conversion of *value* to float; None when it cannot be coerced."""
    try:
        converted = float(value)
    except (ValueError, TypeError):
        return None
    return converted
57+
58+
def convert_to_str(value):
    """Stringify *value*, passing None through unchanged."""
    return None if value is None else str(value)
62+
63+
async def fetch_existing_records(session, batch_size=1000):
    """Load every Item row into a dict keyed by URL, paging through the table.

    Args:
        session: async SQLAlchemy session to query with.
        batch_size: number of rows fetched per round trip.

    Returns:
        dict mapping ``item.url`` -> ``Item`` (later rows with a duplicate
        URL overwrite earlier ones).
    """
    offset = 0
    existing_records = {}
    while True:
        # BUG FIX: OFFSET/LIMIT pagination without a deterministic ORDER BY
        # may skip or duplicate rows across pages; order by the primary key
        # so consecutive pages partition the table consistently.
        query = select(Item).order_by(Item.id).offset(offset).limit(batch_size)
        result = await session.execute(query)
        items = result.scalars().all()
        if not items:
            break
        for item in items:
            existing_records[item.url] = item
        offset += batch_size
        # Lazy %-style args avoid formatting when INFO logging is disabled.
        logger.info("Fetched %d records, offset now %d", len(items), offset)
    return existing_records
77+
78+
async def seed_and_update_embeddings(engine):
    """Synchronise the ``packages`` table with the adjacent ``packages.csv``.

    Inserts CSV rows that are not yet in the table (computing one embedding
    per EMBEDDING_FIELDS entry), updates the price of existing rows whose
    price changed, and deletes rows whose URL no longer appears in the CSV.

    Args:
        engine: async SQLAlchemy engine connected to the target database.
    """
    start_time = time.time()

    logger.info("Checking if the packages table exists...")
    async with engine.begin() as conn:
        result = await conn.execute(
            text(
                """
                SELECT EXISTS
                (SELECT 1 FROM information_schema.tables WHERE table_schema = 'public' AND table_name = 'packages')
                """
            )
        )
        if not result.scalar():
            logger.error("Packages table does not exist. Please run the database setup script first.")
            return

    async with async_sessionmaker(engine, expire_on_commit=False)() as session:
        current_dir = os.path.dirname(os.path.realpath(__file__))
        csv_path = os.path.join(current_dir, "packages.csv")

        try:
            df = pd.read_csv(csv_path, delimiter=',', quotechar='"', escapechar='\\', on_bad_lines='skip', encoding='utf-8')
        except pd.errors.ParserError as e:
            logger.error(f"Error reading CSV file: {e}")
            return

        logger.info("CSV file read successfully. Processing data...")

        # Normalise pandas NaN to None so the driver stores SQL NULLs.
        str_columns = df.select_dtypes(include=[object]).columns
        df[str_columns] = df[str_columns].replace({np.nan: None})
        num_columns = df.select_dtypes(include=[np.number]).columns
        df[num_columns] = df[num_columns].replace({np.nan: None})

        records = df.to_dict(orient='records')
        # URL is the natural key used to match CSV rows to DB rows.
        new_records = {record["url"]: record for record in records}

        logger.info("Fetching existing records from the database...")
        existing_records = await fetch_existing_records(session)
        logger.info(f"Fetched {len(existing_records)} existing records.")

        azure_credential = DefaultAzureCredential()
        try:
            openai_embed_client, openai_embed_model, openai_embed_dimensions = await create_openai_embed_client(azure_credential)

            logger.info("Starting to insert, update, or delete records in the database...")

            for url, record in tqdm(new_records.items(), desc="Processing new records"):
                try:
                    # Coerce numeric columns; every other column is stored as text.
                    record["id"] = convert_to_int(record.get("id"))
                    record["price"] = convert_to_float(record.get("price"))
                    record["cash_discount"] = convert_to_float(record.get("cash_discount"))
                    record["brand_ranking_position"] = convert_to_int(record.get("brand_ranking_position"))

                    if record["price"] is None:
                        logger.error(f"Skipping record with invalid price: {record}")
                        continue

                    existing_item = existing_records.get(url)
                    if existing_item:
                        # Update only the price if there is a change
                        if existing_item.price != record["price"]:
                            existing_item.price = record["price"]
                            # BUG FIX: AsyncSession.merge is a coroutine; it was
                            # previously called without await, so the merge never ran.
                            await session.merge(existing_item)
                            await session.commit()
                            logger.info(f"Updated price for existing record with URL {url}")
                    else:
                        item = _build_item(record)
                        await _populate_embeddings(
                            item, openai_embed_client, openai_embed_model, openai_embed_dimensions
                        )
                        # BUG FIX: await the async merge (see above).
                        await session.merge(item)
                        await session.commit()
                        logger.info(f"Inserted new record with URL {url}")

                except Exception as e:
                    # Best-effort per-record processing: log, roll back, move on.
                    logger.error(f"Error processing record with URL {url}: {e}")
                    await session.rollback()
                    continue

            # Delete rows that are not in the new CSV
            for url in tqdm(existing_records.keys() - new_records.keys(), desc="Deleting outdated records"):
                try:
                    await session.execute(delete(Item).where(Item.url == url))
                    await session.commit()
                    logger.info(f"Deleted outdated record with URL {url}")
                except Exception as e:
                    logger.error(f"Error deleting record with URL {url}: {e}")
                    await session.rollback()
        finally:
            # RESOURCE FIX: the aio DefaultAzureCredential holds network
            # transports; close it to avoid unclosed-session warnings/leaks.
            await azure_credential.close()

    logger.info("All records processed successfully.")
    elapsed_time = time.time() - start_time
    logger.info(f"Total time taken: {elapsed_time:.2f} seconds")


def _build_item(record):
    """Build a new Item from a CSV record.

    Keeps only keys that map to Item columns, blanks out every
    ``embedding_<field>`` column, and stringifies all non-numeric values.
    """
    item_data = {key: value for key, value in record.items() if key in Item.__table__.columns}
    for field in EMBEDDING_FIELDS:
        item_data[f'embedding_{field}'] = None
    for key, value in item_data.items():
        if key not in ["id", "price", "cash_discount", "brand_ranking_position"]:
            item_data[key] = convert_to_str(value)
    return Item(**item_data)


async def _populate_embeddings(item, openai_embed_client, openai_embed_model, openai_embed_dimensions):
    """Compute and assign ``embedding_<field>`` for every embeddable field of *item*.

    Fields with no ``to_str_for_embedding_<field>`` method or an empty text
    representation are skipped; per-field embedding failures are logged and
    do not abort the remaining fields.
    """
    for field in EMBEDDING_FIELDS:
        to_str_method = get_to_str_method(item, field)
        if not to_str_method:
            continue
        field_value = to_str_method()
        if not field_value:
            continue
        try:
            embedding = await compute_text_embedding(
                field_value,
                openai_client=openai_embed_client,
                embed_model=openai_embed_model,
                embedding_dimensions=openai_embed_dimensions,
            )
            setattr(item, f'embedding_{field}', embedding)
            logger.info(f"Updated embedding for {field} of item {item.id}")
        except Exception as e:
            logger.error(f"Error updating embedding for {field} of item {item.id}: {e}")
199+
200+
async def main():
    """Parse CLI arguments, build the Postgres engine, and run the seeding job."""
    parser = argparse.ArgumentParser(description="Seed database with CSV data")
    for opt in ("host", "username", "password", "database", "sslmode"):
        parser.add_argument(f"--{opt}", type=str, help=f"Postgres {opt}")

    args = parser.parse_args()
    # Fall back to environment-based connection settings when no host is given.
    engine = (
        await create_postgres_engine_from_env()
        if args.host is None
        else await create_postgres_engine_from_args(args)
    )

    await seed_and_update_embeddings(engine)
    await engine.dispose()
216+
217+
if __name__ == "__main__":
    # NOTE(review): basicConfig(level=INFO) already ran at import time, so
    # this WARNING call is a no-op (basicConfig does nothing once the root
    # logger has handlers) — confirm the intended default level.
    logging.basicConfig(level=logging.WARNING)
    # Keep this script's own logger verbose regardless of the root level.
    logger.setLevel(logging.INFO)
    load_dotenv(override=True)
    asyncio.run(main())

src/fastapi_app/update_embeddings.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,11 @@
2222
'package_name', 'package_picture', 'url', 'installment_month', 'installment_limit',
2323
'price_to_reserve_for_this_package', 'shop_name', 'category', 'category_tags',
2424
'preview_1_10', 'selling_point', 'meta_keywords', 'brand', 'min_max_age',
25-
'locations_time_open_close_how_to_transport_parking_google_maps', 'meta_description',
26-
'price_details', 'package_details', 'important_info', 'payment_booking_info',
27-
'general_info', 'early_signs_for_diagnosis', 'how_to_diagnose', 'hdcare_summary',
28-
'common_question', 'know_this_disease', 'courses_of_action', 'signals_to_proceed_surgery',
29-
'get_to_know_this_surgery', 'comparisons', 'getting_ready', 'recovery',
30-
'side_effects', 'review_4_5_stars', 'brand_option_in_thai_name', 'faq'
25+
'locations', 'meta_description','price_details', 'package_details', 'important_info',
26+
'payment_booking_info', 'general_info', 'early_signs_for_diagnosis', 'how_to_diagnose',
27+
'hdcare_summary', 'common_question', 'know_this_disease', 'courses_of_action',
28+
'signals_to_proceed_surgery', 'get_to_know_this_surgery', 'comparisons', 'getting_ready',
29+
'recovery', 'side_effects', 'review_4_5_stars', 'brand_option_in_thai_name', 'faq'
3130
]
3231

3332
def get_to_str_method(item, field):

0 commit comments

Comments
 (0)