11import os
2+ from typing import Any
23import yaml
34import multiprocessing
45from datasets import load_dataset
1415)
1516import subprocess
1617from agent .agents import AiderAgents
18+ from agent .agents import AgentTeams
1719from typing import Optional , Type , cast
1820from types import TracebackType
1921from agent .class_types import AgentConfig
@@ -44,7 +46,6 @@ def __exit__(
4446 ) -> None :
4547 os .chdir (self .cwd )
4648
47-
4849def run_agent_for_repo (
4950 repo_base_dir : str ,
5051 agent_config : AgentConfig ,
@@ -56,9 +57,53 @@ def run_agent_for_repo(
5657 log_dir : str = str (RUN_AGENT_LOG_DIR .resolve ()),
5758 commit0_config_file : str = "" ,
5859) -> None :
59- """Run Aider for a given repository."""
60- # get repo info
60+
6161 commit0_config = read_commit0_dot_file (commit0_config_file )
62+ if agent_config .agent_name == "aider" :
63+ run_aider_for_repo (
64+ commit0_config ,
65+ commit0_config ["base_dir" ],
66+ agent_config ,
67+ cast (RepoInstance , example ),
68+ update_queue ,
69+ branch ,
70+ override_previous_changes ,
71+ backend ,
72+ log_dir ,
73+ commit0_config_file ,
74+ )
75+ elif agent_config .agent_name == "aider_team" :
76+ run_team_for_repo (
77+ commit0_config ,
78+ commit0_config ["base_dir" ],
79+ agent_config ,
80+ cast (RepoInstance , example ),
81+ update_queue ,
82+ branch ,
83+ override_previous_changes ,
84+ backend ,
85+ log_dir ,
86+ commit0_config_file ,
87+ )
88+ else :
89+ raise NotImplementedError (
90+ f"{ agent_config .agent_name } is not implemented; please add your implementations in baselines/agents.py."
91+ )
92+
93+ def run_aider_for_repo (
94+ commit0_config : Any ,
95+ repo_base_dir : str ,
96+ agent_config : AgentConfig ,
97+ example : RepoInstance ,
98+ update_queue : multiprocessing .Queue ,
99+ branch : str ,
100+ override_previous_changes : bool = False ,
101+ backend : str = "modal" ,
102+ log_dir : str = str (RUN_AGENT_LOG_DIR .resolve ()),
103+ commit0_config_file : str = "" ,
104+ ) -> None :
105+
106+ """Run Aider for a given repository."""
62107
63108 assert "commit0" in commit0_config ["dataset_name" ]
64109 _ , repo_name = example ["repo" ].split ("/" )
@@ -79,12 +124,7 @@ def run_agent_for_repo(
79124 f"{ repo_path } is not a git repo. Check if base_dir is correctly specified."
80125 )
81126
82- if agent_config .agent_name == "aider" :
83- agent = AiderAgents (agent_config .max_iteration , agent_config .model_name )
84- else :
85- raise NotImplementedError (
86- f"{ agent_config .agent_name } is not implemented; please add your implementations in baselines/agents.py."
87- )
127+ agent = AiderAgents (agent_config .max_iteration , agent_config .model_name )
88128
89129 # Check if there are changes in the current branch
90130 if local_repo .is_dirty ():
@@ -219,6 +259,146 @@ def run_agent_for_repo(
219259 )
220260 update_queue .put (("finish_repo" , repo_name ))
221261
262+ def run_team_for_repo (
263+ commit0_config : Any ,
264+ repo_base_dir : str ,
265+ agent_config : AgentConfig ,
266+ example : RepoInstance ,
267+ update_queue : multiprocessing .Queue ,
268+ branch : str ,
269+ override_previous_changes : bool = False ,
270+ backend : str = "modal" ,
271+ log_dir : str = str (RUN_AGENT_LOG_DIR .resolve ()),
272+ commit0_config_file : str = "" ,
273+ ) -> None :
274+ """Run Aider Team for a given repository."""
275+ # get repo info
276+ assert "commit0" in commit0_config ["dataset_name" ]
277+ _ , repo_name = example ["repo" ].split ("/" )
278+
279+ # before starting, display all information to terminal
280+ update_queue .put (("start_repo" , (repo_name , 0 )))
281+
282+ # repo_name = repo_name.lower()
283+ # repo_name = repo_name.replace(".", "-")
284+
285+ repo_path = os .path .join (repo_base_dir , repo_name )
286+ repo_path = os .path .abspath (repo_path )
287+
288+ try :
289+ local_repo = Repo (repo_path )
290+ except Exception :
291+ raise Exception (
292+ f"{ repo_path } is not a git repo. Check if base_dir is correctly specified."
293+ )
294+
295+ manager = AiderAgents (1 , agent_config .model_name )
296+ coder = AiderAgents (agent_config .max_iteration , agent_config .model_name )
297+
298+ # Check if there are changes in the current branch
299+ if local_repo .is_dirty ():
300+ # Stage all changes
301+ local_repo .git .add (A = True )
302+ # Commit changes with the message "left from last change"
303+ local_repo .index .commit ("left from last change" )
304+
305+ # # if branch_name is not provided, create a new branch name based on agent_config
306+ # if branch is None:
307+ # branch = args2string(agent_config)
308+ create_branch (local_repo , branch , example ["base_commit" ])
309+
310+ # in cases where the latest commit of branch is not commit 0
311+ # set it back to commit 0
312+ latest_commit = local_repo .commit (branch )
313+ if latest_commit .hexsha != example ["base_commit" ] and override_previous_changes :
314+ local_repo .git .reset ("--hard" , example ["base_commit" ])
315+
316+ # get target files to edit and test files to run
317+ target_edit_files , import_dependencies = get_target_edit_files (
318+ local_repo ,
319+ example ["src_dir" ],
320+ example ["test" ]["test_dir" ],
321+ branch ,
322+ example ["reference_commit" ],
323+ agent_config .use_topo_sort_dependencies ,
324+ )
325+
326+ lint_files = get_changed_files_from_commits (
327+ local_repo , "HEAD" , example ["base_commit" ]
328+ )
329+ # Call the commit0 get-tests command to retrieve test files
330+ test_files_str = get_tests (repo_name , verbose = 0 )
331+ test_files = sorted (list (set ([i .split (":" )[0 ] for i in test_files_str ])))
332+
333+ # prepare the log dir
334+ experiment_log_dir = (
335+ Path (log_dir )
336+ / repo_name
337+ / branch
338+ / datetime .now ().strftime ("%Y-%m-%d_%H-%M-%S" )
339+ )
340+ experiment_log_dir .mkdir (parents = True , exist_ok = True )
341+
342+ # write agent_config to .agent.yaml in the log_dir for record
343+ agent_config_log_file = experiment_log_dir / ".agent.yaml"
344+ with open (agent_config_log_file , "w" ) as agent_config_file :
345+ yaml .dump (agent_config , agent_config_file )
346+
347+ manager_message = "Write a concise plan of attack to implement the entire repo, but don't actually do any coding. The plan should not include any reccommendations to add files and should be a maximum of 500 words."
348+
349+ with DirContext (repo_path ):
350+ if agent_config is None :
351+ raise ValueError ("Invalid input" )
352+ else :
353+ # when unit test feedback is not available, iterate over target files to edit
354+
355+ update_queue .put (
356+ ("start_repo" , (repo_name , len (target_edit_files )))
357+ )
358+
359+ #TODO: add support for unit test / lint feedback
360+
361+ for f in target_edit_files :
362+ update_queue .put (("set_current_file" , (repo_name , f )))
363+ dependencies = import_dependencies [f ]
364+ file_name = "all"
365+ file_log_dir = experiment_log_dir / file_name
366+ lint_cmd = get_lint_cmd (repo_name , agent_config .use_lint_info )
367+
368+
369+ agent_return = manager .run (manager_message , "" , lint_cmd , target_edit_files , file_log_dir )
370+ with open (agent_return .log_file , 'r' , encoding = 'utf-8' ) as file :
371+ plan = file .read ()
372+ coder_message = "follow this implementation plan: " + plan
373+
374+ agent_return = coder .run (coder_message , "" , lint_cmd , target_edit_files , file_log_dir )
375+
376+ # for f in target_edit_files:
377+ # update_queue.put(("set_current_file", (repo_name, f)))
378+ # dependencies = import_dependencies[f]
379+ # message = update_message_with_dependencies(coder_message, dependencies)
380+ # file_name = f.replace(".py", "").replace("/", "__")
381+ # file_log_dir = experiment_log_dir / file_name
382+ # lint_cmd = get_lint_cmd(repo_name, agent_config.use_lint_info)
383+ # agent_return = coder.run(message, "", lint_cmd, [f], file_log_dir)
384+ # update_queue.put(
385+ # (
386+ # "update_money_display",
387+ # (repo_name, file_name, agent_return.last_cost),
388+ # )
389+ # )
390+ update_queue .put (
391+ (
392+ "update_money_display" ,
393+ (repo_name , file_name , agent_return .last_cost ),
394+ )
395+ )
396+
397+
398+
399+ update_queue .put (("finish_repo" , repo_name ))
400+
401+
222402
223403def run_agent (
224404 branch : str ,
@@ -303,6 +483,7 @@ def run_agent(
303483 result = pool .apply_async (
304484 run_agent_for_repo ,
305485 args = (
486+ commit0_config ,
306487 commit0_config ["base_dir" ],
307488 agent_config ,
308489 cast (RepoInstance , example ),
0 commit comments