Skip to content

Commit 245400d

Browse files
committed
Add working manager code
1 parent 96309de commit 245400d

File tree

3 files changed

+82
-41
lines changed

3 files changed

+82
-41
lines changed

agent/agent_utils.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -572,12 +572,27 @@ def read_yaml_config(config_file: str) -> dict:
572572
with open(config_file, "r") as f:
573573
return yaml.load(f, Loader=yaml.FullLoader)
574574

575-
576-
def parse_tasks(text: str) -> list[str]:
575+
def parse_tasks(text: str) -> list[tuple[str, str]]:
577576
"""Parse the tasks from the manager output."""
578-
tasks = []
579-
for line in text.strip().splitlines():
580-
if not line.strip()[0].isdigit():
581-
continue
582-
tasks.append(line.strip())
577+
tasks = []
578+
579+
# Extract the portion between PLAN_START and PLAN_END
580+
plan_match = re.search(r"PLAN_START(.*?)PLAN_END", text, re.DOTALL)
581+
if not plan_match:
582+
return tasks
583+
584+
# Get the plan content and split by patterns that indicate the start of a new task
585+
plan_text = plan_match.group(1).strip()
586+
task_blocks = re.split(r"\d+\.?\)?\s+", plan_text) # Split at task numbers
587+
588+
for block in task_blocks:
589+
if not block.strip():
590+
continue
591+
# Match the file name and task description
592+
match = re.search(r"([\w\-/\.0-9]+\.\w+):\s*(.*)", block.strip(), re.DOTALL)
593+
if match:
594+
file_name = match.group(1)
595+
description = re.sub(r'\s+', ' ', match.group(2)).strip() # Remove extra spaces
596+
tasks.append((file_name, description))
597+
583598
return tasks

agent/agents.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,24 @@ def get_money_cost(self) -> float:
5757
last_cost = float(match.group(1))
5858
return last_cost
5959

60+
class TeamsReturn(AgentReturn):
61+
def __init__(self, log_file: Path):
62+
super().__init__(log_file)
63+
self.last_cost = self.get_money_cost()
64+
65+
def get_money_cost(self) -> float:
66+
"""Get accumulated money cost from log file"""
67+
last_cost = 0.0
68+
with open(self.log_file, "r") as file:
69+
for line in file:
70+
if "Tokens:" in line and "Cost:" in line:
71+
match = re.search(
72+
r"Cost: \$\d+\.\d+ message, \$(\d+\.\d+) session", line
73+
)
74+
if match:
75+
last_cost = float(match.group(1))
76+
return last_cost
77+
6078
class AgentTeams(Agents):
6179
def __init__(self, max_iteration: int, model_name: str):
6280
super().__init__(max_iteration)
@@ -97,13 +115,14 @@ def run_manager(
97115
chat_history_file=chat_history_file,
98116
)
99117
manager = Coder.create(
118+
edit_format="ask",
100119
main_model=self.model,
101120
read_only_fnames=fnames,
102121
io=io,
103122
)
104123
manager.max_reflection = self.max_iteration
105124
manager.stream = True
106-
125+
107126
manager.run(message)
108127

109128
sys.stdout.close()
@@ -112,7 +131,7 @@ def run_manager(
112131
sys.stdout = sys.__stdout__
113132
sys.stderr = sys.__stderr__
114133

115-
return AgentReturn(log_file)
134+
return TeamsReturn(log_file)
116135

117136
def run(
118137
self,
@@ -163,7 +182,7 @@ def run(
163182
main_model=self.model,
164183

165184
#make the coder import files on its own for now
166-
fnames=[],
185+
fnames=fnames,
167186
auto_lint=auto_lint,
168187
auto_test=auto_test,
169188
lint_cmds={"python": lint_cmd},
@@ -180,7 +199,7 @@ def run(
180199
sys.stdout = sys.__stdout__
181200
sys.stderr = sys.__stderr__
182201

183-
return AgentReturn(log_file)
202+
return TeamsReturn(log_file)
184203

185204
class AiderAgents(Agents):
186205
def __init__(self, max_iteration: int, model_name: str):

agent/run_agent.py

Lines changed: 37 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,7 @@
1515
parse_tasks,
1616
)
1717
import subprocess
18-
from agent.agents import AiderAgents
19-
from agent.agents import AgentTeams
18+
from agent.agents import AiderAgents, AgentTeams
2019
from typing import Optional, Type, cast
2120
from types import TracebackType
2221
from agent.class_types import AgentConfig
@@ -30,7 +29,6 @@
3029
import queue
3130
import time
3231

33-
3432
class DirContext:
3533
def __init__(self, d: str):
3634
self.dir = d
@@ -273,6 +271,7 @@ def run_team_for_repo(
273271
log_dir: str = str(RUN_AGENT_LOG_DIR.resolve()),
274272
commit0_config_file: str = "",
275273
) -> None:
274+
276275
"""Run Aider Team for a given repository."""
277276
# get repo info
278277
assert "commit0" in commit0_config["dataset_name"]
@@ -294,7 +293,7 @@ def run_team_for_repo(
294293
f"{repo_path} is not a git repo. Check if base_dir is correctly specified."
295294
)
296295

297-
manager = AgentTeams(agent_config.max_iteration, agent_config.model_name)
296+
manager = AgentTeams(1, agent_config.model_name)
298297
coder = AgentTeams(agent_config.max_iteration, agent_config.model_name)
299298

300299
# Check if there are changes in the current branch
@@ -347,7 +346,19 @@ def run_team_for_repo(
347346
yaml.dump(agent_config, agent_config_file)
348347

349348
# /ask will make aider not write any code, but only a plan
350-
manager_message = "/ask You are a manager in charge of writing a plan to complete the implementations for all functions (i.e., those with pass statements) and pass the unit tests. Write a concise plan of attack to implement the entire repo, but don't actually do any coding. Please output the plan in the format of a list of numbered steps. Each step should specify a file to edit and a high-level description of the change to make. For example, '1.) file.py: add a function to calculate the sum of two numbers'. Note that we only need to edit the files that contain functions with pass statements, ie. those in the current context. Give me only the plan, with no extraneous text."
349+
manager_message = f"""You are a manager in charge of writing a plan to complete the implementations for all functions (i.e., those with pass statements) and pass the unit tests. Write a concise plan of attack to implement the entire repo, but don't actually do any coding. Please output the plan in the format of a list of numbered steps. Each step should specify a file to edit and a high-level description of the change to make. Note that we only need to edit the files that contain functions with pass statements, ie. those in the current context. Give me only the plan, with no extraneous text.
350+
351+
You MUST precede the plan with the keyword PLAN_START, and end it with the keyword PLAN_END. You MUST follow the formatting of the example plan below, with a number preceding each step on a new line, and one file name followed by a colon and a description of the change to make:
352+
353+
PLAN_START
354+
1.) example_file.py: description of function(s) to implement in example_file.py
355+
2.) example_file2.py: description of function(s) to implement in example_file2.py
356+
...
357+
PLAN_END
358+
359+
Remember that you must modify all of the target edit files: {target_edit_files}
360+
The plan does not neccessarily need to edit the whole file in one step, and it may be more granular as you see fit.
361+
"""
351362

352363
with DirContext(repo_path):
353364
if agent_config is None:
@@ -361,51 +372,46 @@ def run_team_for_repo(
361372

362373
#TODO: add support for unit test / lint feedback
363374

364-
for f in target_edit_files:
365-
update_queue.put(("set_current_file", (repo_name, f)))
366-
dependencies = import_dependencies[f]
367375
file_name = "all"
368376
file_log_dir = experiment_log_dir / file_name
369377
lint_cmd = get_lint_cmd(repo_name, agent_config.use_lint_info, commit0_config_file)
370378

371379

372-
"""
373-
#uncommenting below works, but the manager.run_manager line doesnt work idk why
380+
# #uncommenting below works, but the manager.run_manager line doesnt work idk why
374381

375-
coder_message = f"Complete the following task, implementing the relevant incomplete functions (i.e., those with pass statements). You may add the specified file to the context if necessary:"
382+
# coder_message = f"Complete the following task, implementing the relevant incomplete functions (i.e., those with pass statements). You may add the specified file to the context if necessary:"
376383

377-
agent_return = coder.run(coder_message, "", lint_cmd, target_edit_files, file_log_dir)
378-
"""
384+
# agent_return = coder.run(coder_message, "", lint_cmd, target_edit_files, file_log_dir)
379385

380386
agent_return = manager.run_manager(manager_message, target_edit_files, file_log_dir)
381387

382-
#TODO: uncomment below after figuring out why manager.run_manager doesnt work
383-
384-
# with open(agent_return.log_file, 'r', encoding='utf-8') as file:
385-
# plan = file.read()
386-
387388
# update_queue.put(
388389
# (
389390
# "update_money_display",
390391
# (repo_name, file_name, agent_return.last_cost),
391392
# )
392393
# )
394+
395+
with open(agent_return.log_file, 'r', encoding='utf-8') as file:
396+
plan = file.read()
393397

394-
# tasks = parse_tasks(plan)
398+
tasks = parse_tasks(plan)
395399

396-
# for task in tasks:
397-
# coder_message = f"Complete the following task, implementing the relevant incomplete functions (i.e., those with pass statements). You may add the specified file to the context if necessary: \n{task}"
400+
for file_name, description in tasks:
401+
update_queue.put(("set_current_file", (repo_name, file_name)))
398402

399-
# agent_return = coder.run(coder_message, "", lint_cmd, target_edit_files, file_log_dir)
400-
# #TODO: fix the display (right now it just displys one file)
401-
403+
coder_message = f"Complete the following task, implementing the relevant incomplete functions (i.e., those with pass statements): \n{description}"
404+
405+
agent_return = coder.run(coder_message, "", lint_cmd, [file_name], file_log_dir)
406+
407+
#TODO: fix the display (right now it just displys one file)
402408

403-
# update_queue.put(
404-
# (
405-
# "update_money_display",
406-
# (repo_name, file_name, agent_return.last_cost),
407-
# )
408-
# )
409+
update_queue.put(
410+
(
411+
"update_money_display",
412+
(repo_name, file_name, agent_return.last_cost),
413+
)
414+
)
409415

410416

411417

@@ -486,6 +492,7 @@ def run_agent(
486492
agent_config.use_lint_info,
487493
)
488494
display.update_branch_display(branch)
495+
489496
with multiprocessing.Manager() as manager:
490497
update_queue = manager.Queue()
491498
with multiprocessing.Pool(processes=max_parallel_repos) as pool:

0 commit comments

Comments
 (0)