99from import_deps import ModuleSet
1010from graphlib import TopologicalSorter , CycleError
1111import yaml
12-
12+ from rank_bm25 import BM25Okapi
1313from agent .class_types import AgentConfig
14+ import subprocess
1415
1516PROMPT_HEADER = ">>> Here is the Task:\n "
17+ FUNCTION_HEADER = "\n \n >>> Here are all functions in the file, complete the implementations for all functions (i.e., those with pass statements):\n "
1618REFERENCE_HEADER = "\n \n >>> Here is the Reference for you to finish the task:\n "
1719REPO_INFO_HEADER = "\n \n >>> Here is the Repository Information:\n "
1820UNIT_TESTS_INFO_HEADER = "\n \n >>> Here are the Unit Tests Information:\n "
1921LINT_INFO_HEADER = "\n \n >>> Here is the Lint Information:\n "
2022SPEC_INFO_HEADER = "\n \n >>> Here is the Specification Information:\n "
2123IMPORT_DEPENDENCIES_HEADER = "\n \n >>> Here are the Import Dependencies:\n "
24+ FUNCTION_BY_FUNCTION_HEADER = """"\n Your task is to implement function {unimplemented_functions} by replacing the pass statement with actual functional code.
25+ Please note that there could be multiple occurrences of {unimplemented_functions}, and you need to implement them all.
26+ Do not change the names of existing functions or classes, as they may be referenced from other code like unit tests, etc.
27+ When you generate code, you must maintain the original formatting of the function stubs (such as whitespaces), otherwise we will not able to search/replace blocks for code modifications, and therefore you will receive a score of 0 for your generated code."""
2228# prefix components:
2329space = " "
2430branch = "│ "
@@ -123,6 +129,32 @@ def get_file_info(file_path: Path, prefix: str = "") -> str:
123129 return "\n " .join (filter (None , tree_string ))
124130
125131
132+ def get_unimplemented_functions (file_path : Path ) -> List [str ]:
133+ """Get all the functions in a file."""
134+ with open (file_path , "r" ) as f :
135+ content = f .read ()
136+
137+ # Find all function definitions with their bodies
138+ pattern = r"def\s+(\w+)\s*\([^)]*\)[^:]*:(?:\s*(?:'''[\s\S]*?'''|\"\"\"[\s\S]*?\"\"\"))?\s*((?:(?!\ndef\s+).)*?)(?=\s*def\s+|\s*$)"
139+ matches = re .finditer (pattern , content , re .DOTALL )
140+
141+ # Keep only functions that have just 'pass'
142+ # List to store unimplemented function definitions
143+ unimplemented_functions = []
144+ for match in matches :
145+ func_name = match .group (1 )
146+ func_body = match .group (2 ).strip ()
147+ # Check if function only contains 'pass' statement
148+ if "pass" in func_body :
149+ unimplemented_functions .append (f"def { func_name } ()" )
150+ # # Find the full function definition using regex pattern
151+ # func_pattern = rf"def\s+{func_name}\s*\([^)]*\)[^:]*:"
152+ # func_match = re.search(func_pattern, content)
153+ # if func_match:
154+ # unimplemented.append(func_match.group(0))
155+ return unimplemented_functions
156+
157+
126158def collect_test_files (directory : str ) -> list [str ]:
127159 """Collect all the test files in the directory."""
128160 test_files = []
@@ -265,9 +297,9 @@ def get_target_edit_files(
265297 raise ValueError (
266298 "topological_sort_files should not be longer than filtered_files"
267299 )
268- assert len (topological_sort_files ) == len (
269- filtered_files
270- ), "all files should be included"
300+ assert len (topological_sort_files ) == len (filtered_files ), (
301+ "all files should be included"
302+ )
271303
272304 # change to latest commit
273305 local_repo .git .checkout (branch )
@@ -324,9 +356,9 @@ def get_target_edit_files_from_patch(
324356 raise ValueError (
325357 "topological_sort_files should not be longer than target_files_list"
326358 )
327- assert len (topological_sort_files ) == len (
328- target_files_list
329- ), "all files should be included"
359+ assert len (topological_sort_files ) == len (target_files_list ), (
360+ "all files should be included"
361+ )
330362
331363 topological_sort_files = [
332364 file .replace (working_dir , "" ).lstrip ("/" ) for file in topological_sort_files
@@ -347,6 +379,7 @@ def get_message(
347379 agent_config : AgentConfig ,
348380 repo_path : str ,
349381 test_files : list [str ] | None = None ,
382+ input_file : str | None = None ,
350383) -> str :
351384 """Get the message to Aider."""
352385 prompt = f"{ PROMPT_HEADER } " + agent_config .user_prompt
@@ -383,11 +416,11 @@ def get_message(
383416 with bz2 .open ("spec.pdf.bz2" , "rb" ) as in_file :
384417 with open ("spec.pdf" , "wb" ) as out_file :
385418 out_file .write (in_file .read ())
386- spec_info = (
387- f" \n { SPEC_INFO_HEADER } "
388- + get_specification ( specification_pdf_path = Path ( repo_path , "spec.pdf" ))[
389- : agent_config . max_spec_info_length
390- ]
419+ spec_info = f" \n { SPEC_INFO_HEADER } " + get_specification (
420+ specification_pdf_path = Path ( repo_path , "spec.pdf" ),
421+ use_retrieval = True ,
422+ query = input_file if input_file else "" ,
423+ top_k = 10 ,
391424 )
392425 else :
393426 spec_info = ""
@@ -397,6 +430,39 @@ def get_message(
397430 return message_to_agent
398431
399432
433+ def get_message_function_by_function (
434+ agent_config : AgentConfig ,
435+ repo_path : str ,
436+ input_file : str ,
437+ test_files : list [str ] | None = None ,
438+ ) -> list [str ]:
439+ """Get the message to Aider."""
440+ context = get_message (agent_config , repo_path , test_files )
441+
442+ if agent_config .implementation_strategy == "module_by_module" :
443+ function_info = []
444+ elif agent_config .implementation_strategy == "function_by_function" :
445+ function_info = []
446+ unimplemented_functions = get_unimplemented_functions (
447+ file_path = Path (os .path .join (repo_path , input_file ))
448+ )
449+ # Get the original function stubs and filter out implemented functions
450+ for i in range (len (unimplemented_functions )):
451+ function_info .append (
452+ FUNCTION_BY_FUNCTION_HEADER .format (
453+ unimplemented_functions = unimplemented_functions [i ]
454+ )
455+ )
456+ else :
457+ raise ValueError (
458+ f"Invalid implementation strategy: { agent_config .implementation_strategy } "
459+ )
460+
461+ messages_to_agent = [context + uf for uf in unimplemented_functions ]
462+
463+ return messages_to_agent
464+
465+
400466def update_message_with_dependencies (message : str , dependencies : list [str ]) -> str :
401467 """Update the message with the dependencies."""
402468 if len (dependencies ) == 0 :
@@ -411,19 +477,43 @@ def update_message_with_dependencies(message: str, dependencies: list[str]) -> s
411477 return message
412478
413479
414- def get_specification (specification_pdf_path : Path ) -> str :
480+ def get_specification (
481+ specification_pdf_path : Path ,
482+ use_retrieval : bool = True ,
483+ query : str = "" ,
484+ top_k : int = 20 ,
485+ ) -> str :
415486 """Get the reference for a given specification PDF path."""
416487 # TODO: after pdf_to_text is available, use it to extract the text from the PDF
417488 # Open the specified PDF file
489+
418490 document = fitz .open (specification_pdf_path )
419- text = ""
491+ corpus = []
420492
493+ # current_trunk = ""
421494 # Iterate through the pages
422495 for page_num in range (len (document )):
423496 page = document .load_page (page_num ) # loads the specified page
424- text += page .get_text () # type: ignore
425497
426- return text
498+ current_page_text = page .get_text () # type: ignore
499+ # Cut page text into chunks of 1000 characters
500+ text_chunks = [
501+ current_page_text [i : i + 1000 ]
502+ for i in range (0 , len (current_page_text ), 1000 )
503+ ]
504+ corpus .extend (text_chunks )
505+ # corpus.append(page.get_text()) # type: ignore
506+ if not use_retrieval :
507+ return "\n " .join (corpus )
508+
509+ assert query != "" , "query should not be empty"
510+ query = open (query ).read ()
511+ tokenized_corpus = [doc .split (" " ) for doc in corpus ]
512+ bm25 = BM25Okapi (tokenized_corpus )
513+ doc_scores = bm25 .get_scores (query )
514+ sorted_doc_scores = sorted (enumerate (doc_scores ), key = lambda x : x [1 ], reverse = True )
515+ sorted_doc_indices = [i for i , _ in sorted_doc_scores ]
516+ return "\n " .join (corpus [i ] for i in sorted_doc_indices [:top_k ])
427517
428518
429519def create_branch (repo : git .Repo , branch : str , from_commit : str ) -> None :
@@ -486,6 +576,21 @@ def get_changed_files_from_commits(
486576 return []
487577
488578
579+ def run_eval_after_each_commit (
580+ branch : str , backend : str , commit0_config_file : str , repo_name : str
581+ ) -> str :
582+ """Run the eval command after each commit."""
583+ eval_cmd = f"python -m commit0 evaluate --branch { branch } --backend { backend } --commit0-config-file { commit0_config_file } --timeout 100"
584+ try :
585+ result = subprocess .run (
586+ eval_cmd , shell = True , capture_output = True , text = True , check = True
587+ )
588+ return result .stdout
589+ except subprocess .CalledProcessError as e :
590+ print (f"Error running eval command: { e } " )
591+ return e .stdout if e .stdout else str (e )
592+
593+
489594def args2string (agent_config : AgentConfig ) -> str :
490595 """Converts specific fields from an `AgentConfig` object into a formatted string.
491596
0 commit comments