"""Seed the packages table from packages.csv and backfill per-field embeddings."""

import argparse
import asyncio
import logging
import os
import time

import numpy as np
import pandas as pd
from azure.identity.aio import DefaultAzureCredential
from dotenv import load_dotenv
from sqlalchemy import delete, select, text
from sqlalchemy.ext.asyncio import async_sessionmaker
from tqdm import tqdm

from fastapi_app.embeddings import compute_text_embedding
from fastapi_app.openai_clients import create_openai_embed_client
from fastapi_app.postgres_engine import (
    create_postgres_engine_from_args,
    create_postgres_engine_from_env,
)
from fastapi_app.postgres_models import Item

load_dotenv()

# Root logging is configured in the __main__ block; calling basicConfig here too
# would make the later basicConfig(level=logging.WARNING) call a no-op.
logger = logging.getLogger("ragapp")

EMBEDDING_FIELDS = [
    'package_name', 'package_picture', 'url', 'installment_month', 'installment_limit',
    'price_to_reserve_for_this_package', 'shop_name', 'category', 'category_tags',
    'preview_1_10', 'selling_point', 'meta_keywords', 'brand', 'min_max_age',
    'locations', 'meta_description', 'price_details', 'package_details',
    'important_info', 'payment_booking_info', 'general_info', 'early_signs_for_diagnosis',
    'how_to_diagnose', 'hdcare_summary', 'common_question', 'know_this_disease',
    'courses_of_action', 'signals_to_proceed_surgery', 'get_to_know_this_surgery',
    'comparisons', 'getting_ready', 'recovery', 'side_effects', 'review_4_5_stars',
    'brand_option_in_thai_name', 'faq'
]

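# Each embeddable field is expected to have a matching to_str_for_embedding_<field>()
# helper on the Item model (fastapi_app.postgres_models); this returns that bound
# method, or None when the model does not define one.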
def get_to_str_method(item, field):
    method_name = f"to_str_for_embedding_{field}"
    return getattr(item, method_name, None)

def convert_to_int(value):
    try:
        return int(value)
    except (ValueError, TypeError):
        return None

def convert_to_float(value):
    try:
        return float(value)
    except (ValueError, TypeError):
        return None

def convert_to_str(value):
    if value is None:
        return None
    return str(value)

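# Page through the packages table in fixed-size batches and key rows by URL,
# which the sync logic below treats as the record's natural identifier.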
async def fetch_existing_records(session, batch_size=1000):
    offset = 0
    existing_records = {}
    while True:
        # A stable ORDER BY is needed for offset/limit paging to visit each row exactly once
        query = select(Item).order_by(Item.id).offset(offset).limit(batch_size)
        result = await session.execute(query)
        items = result.scalars().all()
        if not items:
            break
        for item in items:
            existing_records[item.url] = item
        offset += batch_size
        logger.info(f"Fetched {len(items)} records, offset now {offset}")
    return existing_records

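# Reconcile the database against the CSV: verify the table exists, load and clean
# the CSV, then insert, update, or delete rows so the table mirrors the file.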
async def seed_and_update_embeddings(engine):
    start_time = time.time()
    logger.info("Checking if the packages table exists...")
    async with engine.begin() as conn:
        result = await conn.execute(
            text(
                """
                SELECT EXISTS
                (SELECT 1 FROM information_schema.tables WHERE table_schema = 'public' AND table_name = 'packages')
                """
            )
        )
        if not result.scalar():
            logger.error("Packages table does not exist. Please run the database setup script first.")
            return

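    # packages.csv is expected to sit next to this script; malformed lines are
    # skipped rather than aborting the whole seed run.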
    async with async_sessionmaker(engine, expire_on_commit=False)() as session:
        current_dir = os.path.dirname(os.path.realpath(__file__))
        csv_path = os.path.join(current_dir, "packages.csv")

        try:
            df = pd.read_csv(
                csv_path,
                delimiter=',',
                quotechar='"',
                escapechar='\\',
                on_bad_lines='skip',
                encoding='utf-8',
            )
        except pd.errors.ParserError as e:
            logger.error(f"Error reading CSV file: {e}")
            return

        logger.info("CSV file read successfully. Processing data...")

        str_columns = df.select_dtypes(include=[object]).columns
        df[str_columns] = df[str_columns].replace({np.nan: None})

        num_columns = df.select_dtypes(include=[np.number]).columns
        df[num_columns] = df[num_columns].replace({np.nan: None})

        records = df.to_dict(orient='records')
        new_records = {record["url"]: record for record in records}

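        # Sync strategy: URLs in both CSV and DB get a price-only update, URLs
        # only in the CSV are inserted (with embeddings), and URLs only in the
        # DB are deleted at the end.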
        logger.info("Fetching existing records from the database...")

        existing_records = await fetch_existing_records(session)

        logger.info(f"Fetched {len(existing_records)} existing records.")

        azure_credential = DefaultAzureCredential()
        openai_embed_client, openai_embed_model, openai_embed_dimensions = await create_openai_embed_client(
            azure_credential
        )

        logger.info("Starting to insert, update, or delete records in the database...")

        for url, record in tqdm(new_records.items(), desc="Processing new records"):
            try:
                record["id"] = convert_to_int(record.get("id"))
                record["price"] = convert_to_float(record.get("price"))
                record["cash_discount"] = convert_to_float(record.get("cash_discount"))
                record["brand_ranking_position"] = convert_to_int(record.get("brand_ranking_position"))

                if record["price"] is None:
                    logger.error(f"Skipping record with invalid price: {record}")
                    continue

                existing_item = existing_records.get(url)

                if existing_item:
                    # Update only the price if there is a change
                    if existing_item.price != record["price"]:
                        existing_item.price = record["price"]
                        # AsyncSession.merge is a coroutine and must be awaited
                        await session.merge(existing_item)
                        await session.commit()
                        logger.info(f"Updated price for existing record with URL {url}")
                else:
                    # Insert new item
                    item_data = {key: value for key, value in record.items() if key in Item.__table__.columns}
                    for field in EMBEDDING_FIELDS:
                        item_data[f'embedding_{field}'] = None

                    for key, value in item_data.items():
                        if key not in ["id", "price", "cash_discount", "brand_ranking_position"]:
                            item_data[key] = convert_to_str(value)

                    item = Item(**item_data)

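                    # One vector per populated field; a failure on one field is
                    # logged and skipped so it does not sink the whole record.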
                    # Generate embeddings for the new item
                    for field in EMBEDDING_FIELDS:
                        to_str_method = get_to_str_method(item, field)
                        if to_str_method:
                            field_value = to_str_method()
                            if field_value:
                                try:
                                    embedding = await compute_text_embedding(
                                        field_value,
                                        openai_client=openai_embed_client,
                                        embed_model=openai_embed_model,
                                        embedding_dimensions=openai_embed_dimensions,
                                    )
                                    setattr(item, f'embedding_{field}', embedding)
                                    logger.info(f"Updated embedding for {field} of item {item.id}")
                                except Exception as e:
                                    logger.error(f"Error updating embedding for {field} of item {item.id}: {e}")

                    # AsyncSession.merge is a coroutine and must be awaited
                    await session.merge(item)
                    await session.commit()
                    logger.info(f"Inserted new record with URL {url}")

            except Exception as e:
                logger.error(f"Error processing record with URL {url}: {e}")
                await session.rollback()
                continue

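        # existing_records.keys() - new_records.keys() is plain dict-view set
        # difference: URLs present in the database but absent from the CSV.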
        # Delete rows that are not in the new CSV
        for url in tqdm(existing_records.keys() - new_records.keys(), desc="Deleting outdated records"):
            try:
                await session.execute(delete(Item).where(Item.url == url))
                await session.commit()
                logger.info(f"Deleted outdated record with URL {url}")
            except Exception as e:
                logger.error(f"Error deleting record with URL {url}: {e}")
                await session.rollback()

        # The aio credential holds network sessions and should be closed explicitly
        await azure_credential.close()

        logger.info("All records processed successfully.")
        end_time = time.time()
        elapsed_time = end_time - start_time
        logger.info(f"Total time taken: {elapsed_time:.2f} seconds")

async def main():
    parser = argparse.ArgumentParser(description="Seed database with CSV data")
    parser.add_argument("--host", type=str, help="Postgres host")
    parser.add_argument("--username", type=str, help="Postgres username")
    parser.add_argument("--password", type=str, help="Postgres password")
    parser.add_argument("--database", type=str, help="Postgres database")
    parser.add_argument("--sslmode", type=str, help="Postgres sslmode")

    args = parser.parse_args()
    if args.host is None:
        engine = await create_postgres_engine_from_env()
    else:
        engine = await create_postgres_engine_from_args(args)

    await seed_and_update_embeddings(engine)
    await engine.dispose()

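# Example invocations (script name and argument values are illustrative; when
# --host is omitted, connection details come from the environment/.env file):
#   python seed_data.py
#   python seed_data.py --host localhost --username admin --password secret \
#       --database postgres --sslmode disable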
if __name__ == "__main__":
    logging.basicConfig(level=logging.WARNING)
    logger.setLevel(logging.INFO)
    load_dotenv(override=True)
    asyncio.run(main())