11import os
2+ from typing import Any
23import yaml
34import multiprocessing
45from datasets import load_dataset
1213 read_yaml_config ,
1314)
1415from agent .agents import AiderAgents
16+ from agent .agents import AgentTeams
1517from typing import Optional , Type , cast
1618from types import TracebackType
1719from agent .class_types import AgentConfig
@@ -42,8 +44,44 @@ def __exit__(
4244 ) -> None :
4345 os .chdir (self .cwd )
4446
45-
4647def run_agent_for_repo (
48+ commit0_config : Any ,
49+ repo_base_dir : str ,
50+ agent_config : AgentConfig ,
51+ example : RepoInstance ,
52+ update_queue : multiprocessing .Queue ,
53+ branch : str ,
54+ override_previous_changes : bool = False ,
55+ backend : str = "modal" ,
56+ log_dir : str = str (RUN_AGENT_LOG_DIR .resolve ()),
57+ ) -> None :
58+ if agent_config .agent_name == "aider" :
59+ run_aider_for_repo (
60+ commit0_config ["base_dir" ],
61+ agent_config ,
62+ cast (RepoInstance , example ),
63+ update_queue ,
64+ branch ,
65+ override_previous_changes ,
66+ backend ,
67+ log_dir ,
68+ )
69+ elif agent_config .agent_name == "aider_team" :
70+ run_team_for_repo (
71+ commit0_config ["base_dir" ],
72+ agent_config ,
73+ cast (RepoInstance , example ),
74+ update_queue ,
75+ branch ,
76+ override_previous_changes ,
77+ backend ,
78+ log_dir ,
79+ )
80+ else :
81+ raise NotImplementedError (
82+ f"{ agent_config .agent_name } is not implemented; please add your implementations in baselines/agents.py."
83+ )
84+ def run_aider_for_repo (
4785 repo_base_dir : str ,
4886 agent_config : AgentConfig ,
4987 example : RepoInstance ,
@@ -53,6 +91,8 @@ def run_agent_for_repo(
5391 backend : str = "modal" ,
5492 log_dir : str = str (RUN_AGENT_LOG_DIR .resolve ()),
5593) -> None :
94+
95+ agent = AiderAgents (agent_config .max_iteration , agent_config .model_name )
5696 """Run Aider for a given repository."""
5797 # get repo info
5898 _ , repo_name = example ["repo" ].split ("/" )
@@ -74,12 +114,7 @@ def run_agent_for_repo(
74114 f"{ repo_path } is not a git repo. Check if base_dir is correctly specified."
75115 )
76116
77- if agent_config .agent_name == "aider" :
78- agent = AiderAgents (agent_config .max_iteration , agent_config .model_name )
79- else :
80- raise NotImplementedError (
81- f"{ agent_config .agent_name } is not implemented; please add your implementations in baselines/agents.py."
82- )
117+
83118
84119 # # if branch_name is not provided, create a new branch name based on agent_config
85120 # if branch is None:
@@ -160,6 +195,7 @@ def run_agent_for_repo(
160195 update_queue .put (
161196 ("start_repo" , (original_repo_name , len (target_edit_files )))
162197 )
198+
163199 for f in target_edit_files :
164200 update_queue .put (("set_current_file" , (repo_name , f )))
165201 dependencies = import_dependencies [f ]
@@ -175,6 +211,132 @@ def run_agent_for_repo(
175211 )
176212 )
177213 update_queue .put (("finish_repo" , original_repo_name ))
214+ def run_team_for_repo (
215+ repo_base_dir : str ,
216+ agent_config : AgentConfig ,
217+ example : RepoInstance ,
218+ update_queue : multiprocessing .Queue ,
219+ branch : str ,
220+ override_previous_changes : bool = False ,
221+ backend : str = "modal" ,
222+ log_dir : str = str (RUN_AGENT_LOG_DIR .resolve ()),
223+ ) -> None :
224+ """Run Aider for a given repository."""
225+ # get repo info
226+ _ , repo_name = example ["repo" ].split ("/" )
227+
228+ # before starting, display all information to terminal
229+ original_repo_name = repo_name
230+ update_queue .put (("start_repo" , (original_repo_name , 0 )))
231+
232+ # repo_name = repo_name.lower()
233+ # repo_name = repo_name.replace(".", "-")
234+
235+ repo_path = os .path .join (repo_base_dir , repo_name )
236+ repo_path = os .path .abspath (repo_path )
237+
238+ try :
239+ local_repo = Repo (repo_path )
240+ except Exception :
241+ raise Exception (
242+ f"{ repo_path } is not a git repo. Check if base_dir is correctly specified."
243+ )
244+
245+ manager = AiderAgents (1 , agent_config .model_name )
246+ coder = AiderAgents (agent_config .max_iteration , agent_config .model_name )
247+
248+
249+ # # if branch_name is not provided, create a new branch name based on agent_config
250+ # if branch is None:
251+ # branch = args2string(agent_config)
252+ create_branch (local_repo , branch , example ["base_commit" ])
253+
254+ # in cases where the latest commit of branch is not commit 0
255+ # set it back to commit 0
256+ latest_commit = local_repo .commit (branch )
257+ if latest_commit .hexsha != example ["base_commit" ] and override_previous_changes :
258+ local_repo .git .reset ("--hard" , example ["base_commit" ])
259+
260+ target_edit_files , import_dependencies = get_target_edit_files (
261+ local_repo ,
262+ example ["src_dir" ],
263+ example ["test" ]["test_dir" ],
264+ str (latest_commit ),
265+ example ["reference_commit" ],
266+ )
267+
268+
269+ # Call the commit0 get-tests command to retrieve test files
270+ test_files_str = get_tests (repo_name , verbose = 0 )
271+ test_files = sorted (list (set ([i .split (":" )[0 ] for i in test_files_str ])))
272+
273+ # prepare the log dir
274+ experiment_log_dir = (
275+ Path (log_dir )
276+ / repo_name
277+ / branch
278+ / datetime .now ().strftime ("%Y-%m-%d_%H-%M-%S" )
279+ )
280+ experiment_log_dir .mkdir (parents = True , exist_ok = True )
281+
282+ # write agent_config to .agent.yaml in the log_dir for record
283+ agent_config_log_file = experiment_log_dir / ".agent.yaml"
284+ with open (agent_config_log_file , "w" ) as agent_config_file :
285+ yaml .dump (agent_config , agent_config_file )
286+
287+ # TODO: make this path more general
288+ commit0_dot_file_path = str (Path (repo_path ).parent .parent / ".commit0.yaml" )
289+ manager_message = "Write a concise plan of attack to implement the entire repo, but don't actually do any coding. The plan should not include any reccommendations to add files and should be a maximum of 500 words."
290+
291+ with DirContext (repo_path ):
292+ if agent_config is None :
293+ raise ValueError ("Invalid input" )
294+ else :
295+ # when unit test feedback is not available, iterate over target files to edit
296+
297+ update_queue .put (
298+ ("start_repo" , (original_repo_name , len (target_edit_files )))
299+ )
300+
301+ for f in target_edit_files :
302+ update_queue .put (("set_current_file" , (repo_name , f )))
303+ dependencies = import_dependencies [f ]
304+ file_name = "all"
305+ file_log_dir = experiment_log_dir / file_name
306+ lint_cmd = get_lint_cmd (repo_name , agent_config .use_lint_info )
307+
308+
309+ agent_return = manager .run (manager_message , "" , lint_cmd , target_edit_files , file_log_dir )
310+ with open (agent_return .log_file , 'r' , encoding = 'utf-8' ) as file :
311+ plan = file .read ()
312+ coder_message = "follow this implementation plan: " + plan
313+
314+ agent_return = coder .run (coder_message , "" , lint_cmd , target_edit_files , file_log_dir )
315+
316+ # for f in target_edit_files:
317+ # update_queue.put(("set_current_file", (repo_name, f)))
318+ # dependencies = import_dependencies[f]
319+ # message = update_message_with_dependencies(coder_message, dependencies)
320+ # file_name = f.replace(".py", "").replace("/", "__")
321+ # file_log_dir = experiment_log_dir / file_name
322+ # lint_cmd = get_lint_cmd(repo_name, agent_config.use_lint_info)
323+ # agent_return = coder.run(message, "", lint_cmd, [f], file_log_dir)
324+ # update_queue.put(
325+ # (
326+ # "update_money_display",
327+ # (repo_name, file_name, agent_return.last_cost),
328+ # )
329+ # )
330+ update_queue .put (
331+ (
332+ "update_money_display" ,
333+ (repo_name , file_name , agent_return .last_cost ),
334+ )
335+ )
336+
337+
338+
339+ update_queue .put (("finish_repo" , original_repo_name ))
178340
179341
180342def run_agent (
@@ -248,6 +410,7 @@ def run_agent(
248410 result = pool .apply_async (
249411 run_agent_for_repo ,
250412 args = (
413+ commit0_config ,
251414 commit0_config ["base_dir" ],
252415 agent_config ,
253416 cast (RepoInstance , example ),
@@ -256,6 +419,7 @@ def run_agent(
256419 override_previous_changes ,
257420 backend ,
258421 log_dir ,
422+
259423 ),
260424 )
261425 results .append (result )
0 commit comments