Skip to content

Commit c180a19

Browse files
committed
new file
1 parent a387f4c commit c180a19

File tree

7 files changed

+409
-8
lines changed

7 files changed

+409
-8
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,4 +167,5 @@ config.yml
167167
hydra_outputs/
168168
.commit0*
169169
.agent*
170-
docs/analysis*.md
170+
docs/analysis*.md
171+
.aider*

BRANCH

Whitespace-only changes.

agent/agents.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,101 @@ def run(self) -> AgentReturn:
4949
"""Start agent"""
5050
raise NotImplementedError
5151

52+
class AgentTeams(Agents):
53+
def __init__(self, max_iteration: int, model_name: str):
54+
super().__init__(max_iteration)
55+
self.model = Model(model_name)
56+
def run(
57+
self,
58+
message: str,
59+
test_cmd: str,
60+
lint_cmd: str,
61+
fnames: list[str],
62+
log_dir: Path,
63+
test_first: bool = False,
64+
) -> AgentReturn:
65+
"""Start agent team"""
66+
if test_cmd:
67+
auto_test = True
68+
else:
69+
auto_test = False
70+
if lint_cmd:
71+
auto_lint = True
72+
else:
73+
auto_lint = False
74+
log_dir = log_dir.resolve()
75+
log_dir.mkdir(parents=True, exist_ok=True)
76+
input_history_file = log_dir / ".team.input.history"
77+
chat_history_file = log_dir / ".team.chat.history.md"
78+
79+
# Set up logging
80+
log_file = log_dir / "team.log"
81+
logging.basicConfig(
82+
filename=log_file,
83+
level=logging.INFO,
84+
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
85+
)
86+
87+
# Redirect print statements to the log file
88+
sys.stdout = open(log_file, "a")
89+
sys.stderr = open(log_file, "a")
90+
91+
# Log the message
92+
agent_message_log_file = log_dir / "agent_message.log"
93+
with open(agent_message_log_file, "a") as f:
94+
f.write(f"Message Sent: {message}\n\n")
95+
96+
# Configure httpx and backoff logging
97+
handle_logging("httpx", log_file)
98+
handle_logging("backoff", log_file)
99+
100+
io = InputOutput(
101+
yes=True,
102+
input_history_file=input_history_file,
103+
chat_history_file=chat_history_file,
104+
)
105+
manager = Coder.create(
106+
main_model=self.model,
107+
fnames=fnames,
108+
auto_lint=auto_lint,
109+
auto_test=auto_test,
110+
lint_cmds={"python": lint_cmd},
111+
test_cmd=test_cmd,
112+
io=io,
113+
)
114+
manager.max_reflection = 1
115+
manager.stream = True
116+
117+
# Run the agent
118+
manager_message = "First, add every file in the repo to your conversation. Second, write a plan of attack to implement the entire repo and return this plan."
119+
manager.run(manager_message)
120+
121+
122+
coder = Coder.create(
123+
main_model=self.model,
124+
fnames=fnames,
125+
auto_lint=auto_lint,
126+
auto_test=auto_test,
127+
lint_cmds={"python": lint_cmd},
128+
test_cmd=test_cmd,
129+
io=io,
130+
)
131+
coder.max_reflection = self.max_iteration
132+
coder.stream = True
133+
134+
# Run the agent
135+
with open(chat_history_file, 'r', encoding='utf-8') as file:
136+
plan = file.read()
137+
coder_message = "follow this implementation plan: "+plan
138+
coder.run(coder_message)
139+
140+
sys.stdout.close()
141+
sys.stderr.close()
142+
# Restore original stdout and stderr
143+
sys.stdout = sys.__stdout__
144+
sys.stderr = sys.__stderr__
145+
146+
return AgentReturn(log_file)
52147

53148
class AiderAgents(Agents):
54149
def __init__(self, max_iteration: int, model_name: str):

agent/run_agent.py

Lines changed: 171 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
from typing import Any
23
import yaml
34
import multiprocessing
45
from datasets import load_dataset
@@ -12,6 +13,7 @@
1213
read_yaml_config,
1314
)
1415
from agent.agents import AiderAgents
16+
from agent.agents import AgentTeams
1517
from typing import Optional, Type, cast
1618
from types import TracebackType
1719
from agent.class_types import AgentConfig
@@ -42,8 +44,44 @@ def __exit__(
4244
) -> None:
4345
os.chdir(self.cwd)
4446

45-
4647
def run_agent_for_repo(
48+
commit0_config: Any,
49+
repo_base_dir: str,
50+
agent_config: AgentConfig,
51+
example: RepoInstance,
52+
update_queue: multiprocessing.Queue,
53+
branch: str,
54+
override_previous_changes: bool = False,
55+
backend: str = "modal",
56+
log_dir: str = str(RUN_AGENT_LOG_DIR.resolve()),
57+
) -> None:
58+
if agent_config.agent_name == "aider":
59+
run_aider_for_repo(
60+
commit0_config["base_dir"],
61+
agent_config,
62+
cast(RepoInstance, example),
63+
update_queue,
64+
branch,
65+
override_previous_changes,
66+
backend,
67+
log_dir,
68+
)
69+
elif agent_config.agent_name == "aider_team":
70+
run_team_for_repo(
71+
commit0_config["base_dir"],
72+
agent_config,
73+
cast(RepoInstance, example),
74+
update_queue,
75+
branch,
76+
override_previous_changes,
77+
backend,
78+
log_dir,
79+
)
80+
else:
81+
raise NotImplementedError(
82+
f"{agent_config.agent_name} is not implemented; please add your implementations in baselines/agents.py."
83+
)
84+
def run_aider_for_repo(
4785
repo_base_dir: str,
4886
agent_config: AgentConfig,
4987
example: RepoInstance,
@@ -53,6 +91,8 @@ def run_agent_for_repo(
5391
backend: str = "modal",
5492
log_dir: str = str(RUN_AGENT_LOG_DIR.resolve()),
5593
) -> None:
94+
95+
agent = AiderAgents(agent_config.max_iteration, agent_config.model_name)
5696
"""Run Aider for a given repository."""
5797
# get repo info
5898
_, repo_name = example["repo"].split("/")
@@ -74,12 +114,7 @@ def run_agent_for_repo(
74114
f"{repo_path} is not a git repo. Check if base_dir is correctly specified."
75115
)
76116

77-
if agent_config.agent_name == "aider":
78-
agent = AiderAgents(agent_config.max_iteration, agent_config.model_name)
79-
else:
80-
raise NotImplementedError(
81-
f"{agent_config.agent_name} is not implemented; please add your implementations in baselines/agents.py."
82-
)
117+
83118

84119
# # if branch_name is not provided, create a new branch name based on agent_config
85120
# if branch is None:
@@ -160,6 +195,7 @@ def run_agent_for_repo(
160195
update_queue.put(
161196
("start_repo", (original_repo_name, len(target_edit_files)))
162197
)
198+
163199
for f in target_edit_files:
164200
update_queue.put(("set_current_file", (repo_name, f)))
165201
dependencies = import_dependencies[f]
@@ -175,6 +211,132 @@ def run_agent_for_repo(
175211
)
176212
)
177213
update_queue.put(("finish_repo", original_repo_name))
214+
def run_team_for_repo(
215+
repo_base_dir: str,
216+
agent_config: AgentConfig,
217+
example: RepoInstance,
218+
update_queue: multiprocessing.Queue,
219+
branch: str,
220+
override_previous_changes: bool = False,
221+
backend: str = "modal",
222+
log_dir: str = str(RUN_AGENT_LOG_DIR.resolve()),
223+
) -> None:
224+
"""Run Aider for a given repository."""
225+
# get repo info
226+
_, repo_name = example["repo"].split("/")
227+
228+
# before starting, display all information to terminal
229+
original_repo_name = repo_name
230+
update_queue.put(("start_repo", (original_repo_name, 0)))
231+
232+
# repo_name = repo_name.lower()
233+
# repo_name = repo_name.replace(".", "-")
234+
235+
repo_path = os.path.join(repo_base_dir, repo_name)
236+
repo_path = os.path.abspath(repo_path)
237+
238+
try:
239+
local_repo = Repo(repo_path)
240+
except Exception:
241+
raise Exception(
242+
f"{repo_path} is not a git repo. Check if base_dir is correctly specified."
243+
)
244+
245+
manager = AiderAgents(1, agent_config.model_name)
246+
coder = AiderAgents(agent_config.max_iteration, agent_config.model_name)
247+
248+
249+
# # if branch_name is not provided, create a new branch name based on agent_config
250+
# if branch is None:
251+
# branch = args2string(agent_config)
252+
create_branch(local_repo, branch, example["base_commit"])
253+
254+
# in cases where the latest commit of branch is not commit 0
255+
# set it back to commit 0
256+
latest_commit = local_repo.commit(branch)
257+
if latest_commit.hexsha != example["base_commit"] and override_previous_changes:
258+
local_repo.git.reset("--hard", example["base_commit"])
259+
260+
target_edit_files, import_dependencies = get_target_edit_files(
261+
local_repo,
262+
example["src_dir"],
263+
example["test"]["test_dir"],
264+
str(latest_commit),
265+
example["reference_commit"],
266+
)
267+
268+
269+
# Call the commit0 get-tests command to retrieve test files
270+
test_files_str = get_tests(repo_name, verbose=0)
271+
test_files = sorted(list(set([i.split(":")[0] for i in test_files_str])))
272+
273+
# prepare the log dir
274+
experiment_log_dir = (
275+
Path(log_dir)
276+
/ repo_name
277+
/ branch
278+
/ datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
279+
)
280+
experiment_log_dir.mkdir(parents=True, exist_ok=True)
281+
282+
# write agent_config to .agent.yaml in the log_dir for record
283+
agent_config_log_file = experiment_log_dir / ".agent.yaml"
284+
with open(agent_config_log_file, "w") as agent_config_file:
285+
yaml.dump(agent_config, agent_config_file)
286+
287+
# TODO: make this path more general
288+
commit0_dot_file_path = str(Path(repo_path).parent.parent / ".commit0.yaml")
289+
manager_message = "Write a concise plan of attack to implement the entire repo, but don't actually do any coding. The plan should not include any reccommendations to add files and should be a maximum of 500 words."
290+
291+
with DirContext(repo_path):
292+
if agent_config is None:
293+
raise ValueError("Invalid input")
294+
else:
295+
# when unit test feedback is not available, iterate over target files to edit
296+
297+
update_queue.put(
298+
("start_repo", (original_repo_name, len(target_edit_files)))
299+
)
300+
301+
for f in target_edit_files:
302+
update_queue.put(("set_current_file", (repo_name, f)))
303+
dependencies = import_dependencies[f]
304+
file_name = "all"
305+
file_log_dir = experiment_log_dir / file_name
306+
lint_cmd = get_lint_cmd(repo_name, agent_config.use_lint_info)
307+
308+
309+
agent_return = manager.run(manager_message, "", lint_cmd, target_edit_files, file_log_dir)
310+
with open(agent_return.log_file, 'r', encoding='utf-8') as file:
311+
plan = file.read()
312+
coder_message = "follow this implementation plan: "+plan
313+
314+
agent_return = coder.run(coder_message, "", lint_cmd, target_edit_files, file_log_dir)
315+
316+
# for f in target_edit_files:
317+
# update_queue.put(("set_current_file", (repo_name, f)))
318+
# dependencies = import_dependencies[f]
319+
# message = update_message_with_dependencies(coder_message, dependencies)
320+
# file_name = f.replace(".py", "").replace("/", "__")
321+
# file_log_dir = experiment_log_dir / file_name
322+
# lint_cmd = get_lint_cmd(repo_name, agent_config.use_lint_info)
323+
# agent_return = coder.run(message, "", lint_cmd, [f], file_log_dir)
324+
# update_queue.put(
325+
# (
326+
# "update_money_display",
327+
# (repo_name, file_name, agent_return.last_cost),
328+
# )
329+
# )
330+
update_queue.put(
331+
(
332+
"update_money_display",
333+
(repo_name, file_name, agent_return.last_cost),
334+
)
335+
)
336+
337+
338+
339+
update_queue.put(("finish_repo", original_repo_name))
178340

179341

180342
def run_agent(
@@ -248,6 +410,7 @@ def run_agent(
248410
result = pool.apply_async(
249411
run_agent_for_repo,
250412
args=(
413+
commit0_config,
251414
commit0_config["base_dir"],
252415
agent_config,
253416
cast(RepoInstance, example),
@@ -256,6 +419,7 @@ def run_agent(
256419
override_previous_changes,
257420
backend,
258421
log_dir,
422+
259423
),
260424
)
261425
results.append(result)

0 commit comments

Comments
 (0)