Skip to content

Commit 51ed26c

Browse files
MaxWhitton25EugeneChoi4
authored andcommitted
new file
1 parent 152a327 commit 51ed26c

File tree

7 files changed

+428
-10
lines changed

7 files changed

+428
-10
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,4 +167,5 @@ config.yml
167167
hydra_outputs/
168168
.commit0*
169169
.agent*
170-
docs/analysis*.md
170+
docs/analysis*.md
171+
.aider*

BRANCH

Whitespace-only changes.

agent/agents.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,101 @@ def get_money_cost(self) -> float:
5757
last_cost = float(match.group(1))
5858
return last_cost
5959

60+
class AgentTeams(Agents):
61+
def __init__(self, max_iteration: int, model_name: str):
62+
super().__init__(max_iteration)
63+
self.model = Model(model_name)
64+
def run(
65+
self,
66+
message: str,
67+
test_cmd: str,
68+
lint_cmd: str,
69+
fnames: list[str],
70+
log_dir: Path,
71+
test_first: bool = False,
72+
) -> AgentReturn:
73+
"""Start agent team"""
74+
if test_cmd:
75+
auto_test = True
76+
else:
77+
auto_test = False
78+
if lint_cmd:
79+
auto_lint = True
80+
else:
81+
auto_lint = False
82+
log_dir = log_dir.resolve()
83+
log_dir.mkdir(parents=True, exist_ok=True)
84+
input_history_file = log_dir / ".team.input.history"
85+
chat_history_file = log_dir / ".team.chat.history.md"
86+
87+
# Set up logging
88+
log_file = log_dir / "team.log"
89+
logging.basicConfig(
90+
filename=log_file,
91+
level=logging.INFO,
92+
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
93+
)
94+
95+
# Redirect print statements to the log file
96+
sys.stdout = open(log_file, "a")
97+
sys.stderr = open(log_file, "a")
98+
99+
# Log the message
100+
agent_message_log_file = log_dir / "agent_message.log"
101+
with open(agent_message_log_file, "a") as f:
102+
f.write(f"Message Sent: {message}\n\n")
103+
104+
# Configure httpx and backoff logging
105+
handle_logging("httpx", log_file)
106+
handle_logging("backoff", log_file)
107+
108+
io = InputOutput(
109+
yes=True,
110+
input_history_file=input_history_file,
111+
chat_history_file=chat_history_file,
112+
)
113+
manager = Coder.create(
114+
main_model=self.model,
115+
fnames=fnames,
116+
auto_lint=auto_lint,
117+
auto_test=auto_test,
118+
lint_cmds={"python": lint_cmd},
119+
test_cmd=test_cmd,
120+
io=io,
121+
)
122+
manager.max_reflection = 1
123+
manager.stream = True
124+
125+
# Run the agent
126+
manager_message = "First, add every file in the repo to your conversation. Second, write a plan of attack to implement the entire repo and return this plan."
127+
manager.run(manager_message)
128+
129+
130+
coder = Coder.create(
131+
main_model=self.model,
132+
fnames=fnames,
133+
auto_lint=auto_lint,
134+
auto_test=auto_test,
135+
lint_cmds={"python": lint_cmd},
136+
test_cmd=test_cmd,
137+
io=io,
138+
)
139+
coder.max_reflection = self.max_iteration
140+
coder.stream = True
141+
142+
# Run the agent
143+
with open(chat_history_file, 'r', encoding='utf-8') as file:
144+
plan = file.read()
145+
coder_message = "follow this implementation plan: "+plan
146+
coder.run(coder_message)
147+
148+
sys.stdout.close()
149+
sys.stderr.close()
150+
# Restore original stdout and stderr
151+
sys.stdout = sys.__stdout__
152+
sys.stderr = sys.__stderr__
153+
154+
return AgentReturn(log_file)
60155

61156
class AiderAgents(Agents):
62157
def __init__(self, max_iteration: int, model_name: str):

agent/run_agent.py

Lines changed: 190 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
from typing import Any
23
import yaml
34
import multiprocessing
45
from datasets import load_dataset
@@ -14,6 +15,7 @@
1415
)
1516
import subprocess
1617
from agent.agents import AiderAgents
18+
from agent.agents import AgentTeams
1719
from typing import Optional, Type, cast
1820
from types import TracebackType
1921
from agent.class_types import AgentConfig
@@ -44,7 +46,6 @@ def __exit__(
4446
) -> None:
4547
os.chdir(self.cwd)
4648

47-
4849
def run_agent_for_repo(
4950
repo_base_dir: str,
5051
agent_config: AgentConfig,
@@ -56,9 +57,53 @@ def run_agent_for_repo(
5657
log_dir: str = str(RUN_AGENT_LOG_DIR.resolve()),
5758
commit0_config_file: str = "",
5859
) -> None:
59-
"""Run Aider for a given repository."""
60-
# get repo info
60+
6161
commit0_config = read_commit0_dot_file(commit0_config_file)
62+
if agent_config.agent_name == "aider":
63+
run_aider_for_repo(
64+
commit0_config,
65+
commit0_config["base_dir"],
66+
agent_config,
67+
cast(RepoInstance, example),
68+
update_queue,
69+
branch,
70+
override_previous_changes,
71+
backend,
72+
log_dir,
73+
commit0_config_file,
74+
)
75+
elif agent_config.agent_name == "aider_team":
76+
run_team_for_repo(
77+
commit0_config,
78+
commit0_config["base_dir"],
79+
agent_config,
80+
cast(RepoInstance, example),
81+
update_queue,
82+
branch,
83+
override_previous_changes,
84+
backend,
85+
log_dir,
86+
commit0_config_file,
87+
)
88+
else:
89+
raise NotImplementedError(
90+
f"{agent_config.agent_name} is not implemented; please add your implementations in baselines/agents.py."
91+
)
92+
93+
def run_aider_for_repo(
94+
commit0_config: Any,
95+
repo_base_dir: str,
96+
agent_config: AgentConfig,
97+
example: RepoInstance,
98+
update_queue: multiprocessing.Queue,
99+
branch: str,
100+
override_previous_changes: bool = False,
101+
backend: str = "modal",
102+
log_dir: str = str(RUN_AGENT_LOG_DIR.resolve()),
103+
commit0_config_file: str = "",
104+
) -> None:
105+
106+
"""Run Aider for a given repository."""
62107

63108
assert "commit0" in commit0_config["dataset_name"]
64109
_, repo_name = example["repo"].split("/")
@@ -79,12 +124,7 @@ def run_agent_for_repo(
79124
f"{repo_path} is not a git repo. Check if base_dir is correctly specified."
80125
)
81126

82-
if agent_config.agent_name == "aider":
83-
agent = AiderAgents(agent_config.max_iteration, agent_config.model_name)
84-
else:
85-
raise NotImplementedError(
86-
f"{agent_config.agent_name} is not implemented; please add your implementations in baselines/agents.py."
87-
)
127+
agent = AiderAgents(agent_config.max_iteration, agent_config.model_name)
88128

89129
# Check if there are changes in the current branch
90130
if local_repo.is_dirty():
@@ -219,6 +259,146 @@ def run_agent_for_repo(
219259
)
220260
update_queue.put(("finish_repo", repo_name))
221261

262+
def run_team_for_repo(
263+
commit0_config: Any,
264+
repo_base_dir: str,
265+
agent_config: AgentConfig,
266+
example: RepoInstance,
267+
update_queue: multiprocessing.Queue,
268+
branch: str,
269+
override_previous_changes: bool = False,
270+
backend: str = "modal",
271+
log_dir: str = str(RUN_AGENT_LOG_DIR.resolve()),
272+
commit0_config_file: str = "",
273+
) -> None:
274+
"""Run Aider Team for a given repository."""
275+
# get repo info
276+
assert "commit0" in commit0_config["dataset_name"]
277+
_, repo_name = example["repo"].split("/")
278+
279+
# before starting, display all information to terminal
280+
update_queue.put(("start_repo", (repo_name, 0)))
281+
282+
# repo_name = repo_name.lower()
283+
# repo_name = repo_name.replace(".", "-")
284+
285+
repo_path = os.path.join(repo_base_dir, repo_name)
286+
repo_path = os.path.abspath(repo_path)
287+
288+
try:
289+
local_repo = Repo(repo_path)
290+
except Exception:
291+
raise Exception(
292+
f"{repo_path} is not a git repo. Check if base_dir is correctly specified."
293+
)
294+
295+
manager = AiderAgents(1, agent_config.model_name)
296+
coder = AiderAgents(agent_config.max_iteration, agent_config.model_name)
297+
298+
# Check if there are changes in the current branch
299+
if local_repo.is_dirty():
300+
# Stage all changes
301+
local_repo.git.add(A=True)
302+
# Commit changes with the message "left from last change"
303+
local_repo.index.commit("left from last change")
304+
305+
# # if branch_name is not provided, create a new branch name based on agent_config
306+
# if branch is None:
307+
# branch = args2string(agent_config)
308+
create_branch(local_repo, branch, example["base_commit"])
309+
310+
# in cases where the latest commit of branch is not commit 0
311+
# set it back to commit 0
312+
latest_commit = local_repo.commit(branch)
313+
if latest_commit.hexsha != example["base_commit"] and override_previous_changes:
314+
local_repo.git.reset("--hard", example["base_commit"])
315+
316+
# get target files to edit and test files to run
317+
target_edit_files, import_dependencies = get_target_edit_files(
318+
local_repo,
319+
example["src_dir"],
320+
example["test"]["test_dir"],
321+
branch,
322+
example["reference_commit"],
323+
agent_config.use_topo_sort_dependencies,
324+
)
325+
326+
lint_files = get_changed_files_from_commits(
327+
local_repo, "HEAD", example["base_commit"]
328+
)
329+
# Call the commit0 get-tests command to retrieve test files
330+
test_files_str = get_tests(repo_name, verbose=0)
331+
test_files = sorted(list(set([i.split(":")[0] for i in test_files_str])))
332+
333+
# prepare the log dir
334+
experiment_log_dir = (
335+
Path(log_dir)
336+
/ repo_name
337+
/ branch
338+
/ datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
339+
)
340+
experiment_log_dir.mkdir(parents=True, exist_ok=True)
341+
342+
# write agent_config to .agent.yaml in the log_dir for record
343+
agent_config_log_file = experiment_log_dir / ".agent.yaml"
344+
with open(agent_config_log_file, "w") as agent_config_file:
345+
yaml.dump(agent_config, agent_config_file)
346+
347+
manager_message = "Write a concise plan of attack to implement the entire repo, but don't actually do any coding. The plan should not include any reccommendations to add files and should be a maximum of 500 words."
348+
349+
with DirContext(repo_path):
350+
if agent_config is None:
351+
raise ValueError("Invalid input")
352+
else:
353+
# when unit test feedback is not available, iterate over target files to edit
354+
355+
update_queue.put(
356+
("start_repo", (repo_name, len(target_edit_files)))
357+
)
358+
359+
#TODO: add support for unit test / lint feedback
360+
361+
for f in target_edit_files:
362+
update_queue.put(("set_current_file", (repo_name, f)))
363+
dependencies = import_dependencies[f]
364+
file_name = "all"
365+
file_log_dir = experiment_log_dir / file_name
366+
lint_cmd = get_lint_cmd(repo_name, agent_config.use_lint_info)
367+
368+
369+
agent_return = manager.run(manager_message, "", lint_cmd, target_edit_files, file_log_dir)
370+
with open(agent_return.log_file, 'r', encoding='utf-8') as file:
371+
plan = file.read()
372+
coder_message = "follow this implementation plan: "+plan
373+
374+
agent_return = coder.run(coder_message, "", lint_cmd, target_edit_files, file_log_dir)
375+
376+
# for f in target_edit_files:
377+
# update_queue.put(("set_current_file", (repo_name, f)))
378+
# dependencies = import_dependencies[f]
379+
# message = update_message_with_dependencies(coder_message, dependencies)
380+
# file_name = f.replace(".py", "").replace("/", "__")
381+
# file_log_dir = experiment_log_dir / file_name
382+
# lint_cmd = get_lint_cmd(repo_name, agent_config.use_lint_info)
383+
# agent_return = coder.run(message, "", lint_cmd, [f], file_log_dir)
384+
# update_queue.put(
385+
# (
386+
# "update_money_display",
387+
# (repo_name, file_name, agent_return.last_cost),
388+
# )
389+
# )
390+
update_queue.put(
391+
(
392+
"update_money_display",
393+
(repo_name, file_name, agent_return.last_cost),
394+
)
395+
)
396+
397+
398+
399+
update_queue.put(("finish_repo", repo_name))
400+
401+
222402

223403
def run_agent(
224404
branch: str,
@@ -303,6 +483,7 @@ def run_agent(
303483
result = pool.apply_async(
304484
run_agent_for_repo,
305485
args=(
486+
commit0_config,
306487
commit0_config["base_dir"],
307488
agent_config,
308489
cast(RepoInstance, example),

0 commit comments

Comments
 (0)