+import argparse
 import json
 import logging
 import os
 from dotenv_azd import load_azd_env
 from openai import AzureOpenAI, OpenAI
 from openai.types.chat import ChatCompletionToolParam
-from sqlalchemy import create_engine, select
+from sqlalchemy import create_engine, select, func
 from sqlalchemy.orm import Session
+from dotenv import load_dotenv
+from jinja2 import Environment, FileSystemLoader
+from rich.logging import RichHandler
 
 from fastapi_app.postgres_models import Item
 
 logger = logging.getLogger("ragapp")
 
 
+
 def qa_pairs_tool(num_questions: int = 1) -> ChatCompletionToolParam:
     return {
         "type": "function",
@@ -47,26 +52,19 @@ def qa_pairs_tool(num_questions: int = 1) -> ChatCompletionToolParam:
 
 
 def source_retriever() -> Generator[str, None, None]:
-    # Connect to the database
+    # Connect to the local database
     DBHOST = os.environ["POSTGRES_HOST"]
     DBUSER = os.environ["POSTGRES_USERNAME"]
     DBPASS = os.environ["POSTGRES_PASSWORD"]
     DBNAME = os.environ["POSTGRES_DATABASE"]
     DATABASE_URI = f"postgresql://{DBUSER}:{DBPASS}@{DBHOST}/{DBNAME}"
     engine = create_engine(DATABASE_URI, echo=False)
     with Session(engine) as session:
-        # Fetch all products for a particular type
-        item_types = session.scalars(select(Item.type).distinct())
-        for item_type in item_types:
-            records = list(session.scalars(select(Item).filter(Item.type == item_type).order_by(Item.id)))
-            logger.info(f"Processing database records for type: {item_type}")
-            yield "\n\n".join([f"## Product ID: [{record.id}]\n" + record.to_str_for_rag() for record in records])
-        # Fetch each item individually
-        # records = list(session.scalars(select(Item).order_by(Item.id)))
-        # for record in records:
-        #     logger.info(f"Processing database record: {record.name}")
-        #     yield f"## Product ID: [{record.id}]\n" + record.to_str_for_rag()
-        # await self.openai_chat_client.chat.completions.create(
+        while True:
+            # Fetch all the rows from the database
+            random_rows = list(session.scalars(select(Item).order_by(func.random())))
+            logger.info("Fetched %d random rows", len(random_rows))
+            yield "\n\n".join([f"## Row ID: [{row.id}]\n" + row.to_str_for_rag() for row in random_rows])
 
 
 def source_to_text(source) -> str:
@@ -108,31 +106,36 @@ def get_openai_client() -> tuple[AzureOpenAI | OpenAI, str]:
     return openai_client, model
 
 
-def generate_ground_truth_data(num_questions_total: int, num_questions_per_source: int = 5):
+def generate_ground_truth_data(num_questions_total: int, num_questions_per_source):
     logger.info("Generating %d questions total", num_questions_total)
     openai_client, model = get_openai_client()
     current_dir = Path(__file__).parent
-    generate_prompt = open(current_dir / "generate_prompt.txt").read()
+
+    # Load the template from the file system
+    jinja_file_loader = FileSystemLoader(current_dir)
+    jinja_env = Environment(loader=jinja_file_loader)
+    prompt_template = jinja_env.get_template('generate_prompt.jinja2')
+
     output_file = Path(__file__).parent / "ground_truth.jsonl"
 
     qa: list[dict] = []
-    for source in source_retriever():
-        if len(qa) > num_questions_total:
-            logger.info("Generated enough questions already, stopping")
-            break
+    while len(qa) < num_questions_total:
+        sources = next(source_retriever())
+        previous_questions = [qa_pair["question"] for qa_pair in qa]
         result = openai_client.chat.completions.create(
             model=model,
             messages=[
-                {"role": "system", "content": generate_prompt},
-                {"role": "user", "content": json.dumps(source)},
+                {"role": "system", "content": prompt_template.render(num_questions=num_questions_per_source, previous_questions=previous_questions)},
+                {"role": "user", "content": json.dumps(sources)},
             ],
-            tools=[qa_pairs_tool(num_questions=2)],
+            tools=[qa_pairs_tool(num_questions=num_questions_per_source)],
         )
         if not result.choices[0].message.tool_calls:
             logger.warning("No tool calls found in response, skipping")
             continue
         qa_pairs = json.loads(result.choices[0].message.tool_calls[0].function.arguments)["qa_list"]
         qa_pairs = [{"question": qa_pair["question"], "truth": qa_pair["answer"]} for qa_pair in qa_pairs]
+        logger.info("Received %d suggested questions", len(qa_pairs))
        qa.extend(qa_pairs)
 
     logger.info("Writing %d questions to %s", num_questions_total, output_file)
@@ -145,8 +148,16 @@ def generate_ground_truth_data(num_questions_total: int, num_questions_per_sourc
 
 
 if __name__ == "__main__":
-    logging.basicConfig(level=logging.WARNING)
+    logging.basicConfig(
+        level=logging.WARNING, format="%(message)s", datefmt="[%X]", handlers=[RichHandler(rich_tracebacks=True)]
+    )
     logger.setLevel(logging.INFO)
-    load_azd_env()
+    load_dotenv(".env", override=True)
+
+    parser = argparse.ArgumentParser(description="Run evaluation with OpenAI configuration.")
+    parser.add_argument("--numquestions", type=int, help="Specify the number of questions.", default=50)
+    parser.add_argument("--persource", type=int, help="Specify the number of questions per retrieved sources.", default=5)
+
+    args = parser.parse_args()
 
-    generate_ground_truth_data(num_questions_total=10)
+    generate_ground_truth_data(num_questions_total=args.numquestions, num_questions_per_source=args.persource)
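
Note: the commit swaps the plain-text generate_prompt.txt for a Jinja2 template rendered with num_questions and previous_questions. The generate_prompt.jinja2 file itself is not part of this diff, so the snippet below is only an illustrative sketch of how a template taking those two variables behaves; the template wording and the example question are assumptions, not the repository's actual prompt.

from jinja2 import Environment

# Illustrative sketch only: the real generate_prompt.jinja2 is not shown in this commit.
env = Environment()
template = env.from_string(
    "Generate {{ num_questions }} question/answer pairs grounded in the sources provided by the user.\n"
    "{% if previous_questions %}Do not repeat any of these previously generated questions:\n"
    "{% for q in previous_questions %}- {{ q }}\n{% endfor %}{% endif %}"
)
print(template.render(num_questions=5, previous_questions=["An example of an earlier question?"]))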