Skip to content

Commit cfbef0e

Browse files
committed
update agent ablation code
1 parent 4043a34 commit cfbef0e

File tree

8 files changed

+354
-330
lines changed

8 files changed

+354
-330
lines changed

.gitignore

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -162,9 +162,10 @@ cython_debug/
162162
#.idea/
163163

164164
logs/
165-
repos/
165+
repos*/
166166
config.yml
167167
hydra_outputs/
168168
.commit0*
169169
.agent*
170-
docs/analysis*.md
170+
docs/analysis*.md
171+
agent/run_agent_no_rich.py

agent/agent_utils.py

Lines changed: 121 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,22 @@
99
from import_deps import ModuleSet
1010
from graphlib import TopologicalSorter, CycleError
1111
import yaml
12-
12+
from rank_bm25 import BM25Okapi
1313
from agent.class_types import AgentConfig
14+
import subprocess
1415

1516
# Section headers that delimit the individual parts of the prompt sent to the agent.
PROMPT_HEADER = ">>> Here is the Task:\n"
FUNCTION_HEADER = "\n\n>>> Here are all functions in the file, complete the implementations for all functions (i.e., those with pass statements):\n"
REFERENCE_HEADER = "\n\n>>> Here is the Reference for you to finish the task:\n"
REPO_INFO_HEADER = "\n\n>>> Here is the Repository Information:\n"
UNIT_TESTS_INFO_HEADER = "\n\n>>> Here are the Unit Tests Information:\n"
LINT_INFO_HEADER = "\n\n>>> Here is the Lint Information:\n"
SPEC_INFO_HEADER = "\n\n>>> Here is the Specification Information:\n"
IMPORT_DEPENDENCIES_HEADER = "\n\n>>> Here are the Import Dependencies:\n"
# Per-function instruction template; `{unimplemented_functions}` is filled in
# with str.format. The literal opens with exactly three quotes: a fourth quote
# would become a stray `"` at the very start of the prompt text.
FUNCTION_BY_FUNCTION_HEADER = """\nYour task is to implement function {unimplemented_functions} by replacing the pass statement with actual functional code.
Please note that there could be multiple occurrences of {unimplemented_functions}, and you need to implement them all.
Do not change the names of existing functions or classes, as they may be referenced from other code like unit tests, etc.
When you generate code, you must maintain the original formatting of the function stubs (such as whitespaces), otherwise we will not able to search/replace blocks for code modifications, and therefore you will receive a score of 0 for your generated code."""
2228
# prefix components:
2329
space = " "
2430
branch = "│ "
@@ -123,6 +129,32 @@ def get_file_info(file_path: Path, prefix: str = "") -> str:
123129
return "\n".join(filter(None, tree_string))
124130

125131

132+
def get_unimplemented_functions(file_path: Path) -> List[str]:
    """Return the stub functions in a file that still contain a bare ``pass``.

    The file is scanned with a regular expression (not parsed), so the result
    is a heuristic.

    Args:
        file_path: Path of the Python source file to scan.

    Returns:
        One ``"def <name>()"`` string per matching function, in file order.
        Names that occur several times are reported several times.

    """
    with open(file_path, "r") as f:
        content = f.read()

    # Captures the function name (group 1) and whatever follows the optional
    # docstring up to the next ``def`` or the end of the file (group 2).
    pattern = r"def\s+(\w+)\s*\([^)]*\)[^:]*:(?:\s*(?:'''[\s\S]*?'''|\"\"\"[\s\S]*?\"\"\"))?\s*((?:(?!\ndef\s+).)*?)(?=\s*def\s+|\s*$)"

    # A line holding nothing but `pass` (optionally followed by a comment).
    # A plain substring test would also fire on identifiers such as
    # `password` or `bypass` inside fully implemented functions.
    bare_pass = re.compile(r"^[ \t]*pass[ \t]*(?:#.*)?$", re.MULTILINE)

    return [
        f"def {match.group(1)}()"
        for match in re.finditer(pattern, content, re.DOTALL)
        if bare_pass.search(match.group(2))
    ]
156+
157+
126158
def collect_test_files(directory: str) -> list[str]:
127159
"""Collect all the test files in the directory."""
128160
test_files = []
@@ -265,9 +297,9 @@ def get_target_edit_files(
265297
raise ValueError(
266298
"topological_sort_files should not be longer than filtered_files"
267299
)
268-
assert len(topological_sort_files) == len(
269-
filtered_files
270-
), "all files should be included"
300+
assert len(topological_sort_files) == len(filtered_files), (
301+
"all files should be included"
302+
)
271303

272304
# change to latest commit
273305
local_repo.git.checkout(branch)
@@ -324,9 +356,9 @@ def get_target_edit_files_from_patch(
324356
raise ValueError(
325357
"topological_sort_files should not be longer than target_files_list"
326358
)
327-
assert len(topological_sort_files) == len(
328-
target_files_list
329-
), "all files should be included"
359+
assert len(topological_sort_files) == len(target_files_list), (
360+
"all files should be included"
361+
)
330362

331363
topological_sort_files = [
332364
file.replace(working_dir, "").lstrip("/") for file in topological_sort_files
@@ -347,6 +379,7 @@ def get_message(
347379
agent_config: AgentConfig,
348380
repo_path: str,
349381
test_files: list[str] | None = None,
382+
input_file: str | None = None,
350383
) -> str:
351384
"""Get the message to Aider."""
352385
prompt = f"{PROMPT_HEADER}" + agent_config.user_prompt
@@ -383,11 +416,11 @@ def get_message(
383416
with bz2.open("spec.pdf.bz2", "rb") as in_file:
384417
with open("spec.pdf", "wb") as out_file:
385418
out_file.write(in_file.read())
386-
spec_info = (
387-
f"\n{SPEC_INFO_HEADER} "
388-
+ get_specification(specification_pdf_path=Path(repo_path, "spec.pdf"))[
389-
: agent_config.max_spec_info_length
390-
]
419+
spec_info = f"\n{SPEC_INFO_HEADER} " + get_specification(
420+
specification_pdf_path=Path(repo_path, "spec.pdf"),
421+
use_retrieval=True,
422+
query=input_file if input_file else "",
423+
top_k=10,
391424
)
392425
else:
393426
spec_info = ""
@@ -397,6 +430,39 @@ def get_message(
397430
return message_to_agent
398431

399432

433+
def get_message_function_by_function(
    agent_config: AgentConfig,
    repo_path: str,
    input_file: str,
    test_files: list[str] | None = None,
) -> list[str]:
    """Build one agent message per unimplemented function of ``input_file``.

    Every message is the shared context returned by ``get_message`` followed by
    the ``FUNCTION_BY_FUNCTION_HEADER`` instructions for a single function stub.

    Args:
        agent_config: Agent configuration; ``implementation_strategy`` selects
            the behaviour.
        repo_path: Root directory of the repository.
        input_file: File to implement, relative to ``repo_path``.
        test_files: Optional test files forwarded to ``get_message``.

    Returns:
        The list of messages. It is empty for the ``"module_by_module"``
        strategy, or when the file has no unimplemented function.

    Raises:
        ValueError: If ``implementation_strategy`` is not a known strategy.

    """
    # NOTE(review): `input_file` is not forwarded to get_message here, so any
    # retrieval done by get_message runs without a query -- confirm intended.
    context = get_message(agent_config, repo_path, test_files)

    strategy = agent_config.implementation_strategy
    if strategy == "module_by_module":
        # NOTE(review): no per-function messages are produced for this
        # strategy; confirm whether a single `[context]` message is expected.
        function_info: list[str] = []
    elif strategy == "function_by_function":
        unimplemented_functions = get_unimplemented_functions(
            file_path=Path(os.path.join(repo_path, input_file))
        )
        function_info = [
            FUNCTION_BY_FUNCTION_HEADER.format(unimplemented_functions=stub)
            for stub in unimplemented_functions
        ]
    else:
        raise ValueError(f"Invalid implementation strategy: {strategy}")

    # Append the formatted instructions rather than the bare stub names, and
    # rely only on `function_info`, which is bound for every valid strategy.
    return [context + info for info in function_info]
464+
465+
400466
def update_message_with_dependencies(message: str, dependencies: list[str]) -> str:
401467
"""Update the message with the dependencies."""
402468
if len(dependencies) == 0:
@@ -411,19 +477,43 @@ def update_message_with_dependencies(message: str, dependencies: list[str]) -> s
411477
return message
412478

413479

414-
def get_specification(specification_pdf_path: Path) -> str:
480+
def get_specification(
    specification_pdf_path: Path,
    use_retrieval: bool = True,
    query: str = "",
    top_k: int = 20,
) -> str:
    """Extract the specification text from a PDF, optionally ranked with BM25.

    Each page's text is cut into chunks of 1000 characters. Without retrieval
    all chunks are returned in document order; with retrieval only the
    ``top_k`` chunks scoring highest against the query file are returned.

    Args:
        specification_pdf_path: Path of the specification PDF.
        use_retrieval: Whether to rank the chunks against ``query``.
        query: Path of a text file whose content is the retrieval query.
            Must be non-empty when ``use_retrieval`` is true.
        top_k: Maximum number of chunks returned when retrieval is used.

    Returns:
        The selected chunks joined by newlines.

    """
    # TODO: after pdf_to_text is available, use it to extract the text from the PDF
    chunk_size = 1000
    corpus: list[str] = []

    document = fitz.open(specification_pdf_path)
    try:
        for page_num in range(len(document)):
            page = document.load_page(page_num)  # loads the specified page
            page_text = page.get_text()  # type: ignore
            corpus.extend(
                page_text[i : i + chunk_size]
                for i in range(0, len(page_text), chunk_size)
            )
    finally:
        # Release the PDF handle even when text extraction fails.
        document.close()

    if not use_retrieval:
        return "\n".join(corpus)

    assert query != "", "query should not be empty"
    with open(query) as query_file:
        query_text = query_file.read()

    if not corpus:
        # Nothing to rank; BM25Okapi cannot be built from an empty corpus.
        return ""

    tokenized_corpus = [doc.split(" ") for doc in corpus]
    bm25 = BM25Okapi(tokenized_corpus)
    # The query must be tokenized the same way as the corpus; a raw string
    # would be iterated character by character.
    doc_scores = bm25.get_scores(query_text.split(" "))
    ranked_indices = sorted(
        range(len(corpus)), key=lambda idx: doc_scores[idx], reverse=True
    )
    return "\n".join(corpus[idx] for idx in ranked_indices[:top_k])
427517

428518

429519
def create_branch(repo: git.Repo, branch: str, from_commit: str) -> None:
@@ -486,6 +576,21 @@ def get_changed_files_from_commits(
486576
return []
487577

488578

579+
def run_eval_after_each_commit(
    branch: str, backend: str, commit0_config_file: str, repo_name: str
) -> str:
    """Run ``commit0 evaluate`` for a branch and return its standard output.

    Args:
        branch: Branch to evaluate.
        backend: Backend name passed to ``--backend``.
        commit0_config_file: Path passed to ``--commit0-config-file``.
        repo_name: Currently unused; kept for interface compatibility.

    Returns:
        The command's stdout on success. On failure, the captured stdout if
        there is any, otherwise the string form of the error.

    """
    # An argument list (no shell) keeps values containing spaces or shell
    # metacharacters from being split or interpreted.
    eval_cmd = [
        "python",
        "-m",
        "commit0",
        "evaluate",
        "--branch",
        branch,
        "--backend",
        backend,
        "--commit0-config-file",
        commit0_config_file,
        "--timeout",
        "100",
    ]
    try:
        result = subprocess.run(eval_cmd, capture_output=True, text=True, check=True)
    except subprocess.CalledProcessError as e:
        print(f"Error running eval command: {e}")
        return e.stdout if e.stdout else str(e)
    except OSError as e:
        # Without a shell, a missing interpreter surfaces as OSError; report
        # it the same way as a failed run instead of propagating.
        print(f"Error running eval command: {e}")
        return str(e)
    return result.stdout
592+
593+
489594
def args2string(agent_config: AgentConfig) -> str:
490595
"""Converts specific fields from an `AgentConfig` object into a formatted string.
491596

agent/agents.py

Lines changed: 59 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ def __init__(self, log_file: Path):
2727
self.log_file = log_file
2828

2929
self.last_cost = 0.0
30+
self.total_token_in = 0
31+
self.total_token_out = 0
3032

3133

3234
class Agents(ABC):
@@ -43,6 +45,8 @@ class AiderReturn(AgentReturn):
4345
def __init__(self, log_file: Path):
4446
super().__init__(log_file)
4547
self.last_cost = self.get_money_cost()
48+
self.total_token_in = self.get_total_token_in()
49+
self.total_token_out = self.get_total_token_out()
4650

4751
def get_money_cost(self) -> float:
4852
"""Get accumulated money cost from log file"""
@@ -57,18 +61,54 @@ def get_money_cost(self) -> float:
5761
last_cost = float(match.group(1))
5862
return last_cost
5963

64+
def get_total_token_in(self) -> int:
    """Read the number of sent tokens reported in the log file.

    Scans ``self.log_file`` for lines of the form ``Tokens: <n> sent`` where
    ``<n>`` may carry a ``k`` (thousands) suffix.

    Returns:
        The count from the last matching line, or 0 when no line matches.

    """
    # NOTE(review): only the last reported value is kept, not a running sum,
    # despite the method name -- confirm against the log format.
    sent_tokens = 0
    with open(self.log_file, "r") as file:
        for line in file:
            if "Tokens:" not in line:
                continue
            match = re.search(r"Tokens: ([\d.]+k?) sent", line)
            if not match:
                continue
            token_str = match.group(1)
            if token_str.endswith("k"):
                # round() rather than int(): float error must not truncate
                # a value like 8.2k down to 8199.
                sent_tokens = round(float(token_str[:-1]) * 1000)
            else:
                sent_tokens = int(float(token_str))
    return sent_tokens
78+
79+
def get_total_token_out(self) -> int:
    """Read the number of received tokens reported in the log file.

    Scans ``self.log_file`` for ``Tokens:`` lines containing
    ``<n> received`` where ``<n>`` may carry a ``k`` (thousands) suffix.

    Returns:
        The count from the last matching line, or 0 when no line matches.

    """
    # NOTE(review): only the last reported value is kept, not a running sum,
    # despite the method name -- confirm against the log format.
    received_tokens = 0
    with open(self.log_file, "r") as file:
        for line in file:
            if "Tokens:" not in line:
                continue
            # Accept decimals and the `k` suffix (e.g. "1.5k received");
            # a digits-only pattern never matches suffixed counts.
            match = re.search(r"([\d.]+k?) received", line)
            if not match:
                continue
            token_str = match.group(1)
            if token_str.endswith("k"):
                # round() rather than int(): float error must not truncate.
                received_tokens = round(float(token_str[:-1]) * 1000)
            else:
                received_tokens = int(float(token_str))
    return received_tokens
93+
6094

6195
class AiderAgents(Agents):
6296
def __init__(self, max_iteration: int, model_name: str):
6397
super().__init__(max_iteration)
6498
self.model = Model(model_name)
6599
# Check if API key is set for the model
66-
if "gpt" in model_name:
100+
if "openrouter" in model_name:
101+
api_key = os.environ.get("OPENROUTER_API_KEY", None)
102+
elif "gpt" in model_name:
67103
api_key = os.environ.get("OPENAI_API_KEY", None)
68104
elif "claude" in model_name:
69105
api_key = os.environ.get("ANTHROPIC_API_KEY", None)
70106
elif "gemini" in model_name:
71-
api_key = os.environ.get("API_KEY", None)
107+
api_key = os.environ.get("GEMINI_API_KEY", None)
108+
elif "deepseek" in model_name:
109+
api_key = os.environ.get("DEEPSEEK_API_KEY", None)
110+
elif "mistral" in model_name:
111+
api_key = os.environ.get("MISTRAL_API_KEY", None)
72112
else:
73113
raise ValueError(f"Unsupported model: {model_name}")
74114

@@ -87,6 +127,7 @@ def run(
87127
log_dir: Path,
88128
test_first: bool = False,
89129
lint_first: bool = False,
130+
current_attempt: int = 0,
90131
) -> AgentReturn:
91132
"""Start aider agent"""
92133
if test_cmd:
@@ -99,11 +140,22 @@ def run(
99140
auto_lint = False
100141
log_dir = log_dir.resolve()
101142
log_dir.mkdir(parents=True, exist_ok=True)
102-
input_history_file = log_dir / ".aider.input.history"
103-
chat_history_file = log_dir / ".aider.chat.history.md"
104-
143+
input_history_file = (
144+
log_dir / ".aider.input.history"
145+
if current_attempt == 0
146+
else log_dir / f".aider_{current_attempt}.input.history"
147+
)
148+
chat_history_file = (
149+
log_dir / ".aider.chat.history.md"
150+
if current_attempt == 0
151+
else log_dir / f".aider_{current_attempt}.chat.history.md"
152+
)
105153
# Set up logging
106-
log_file = log_dir / "aider.log"
154+
log_file = (
155+
log_dir / "aider.log"
156+
if current_attempt == 0
157+
else log_dir / f"aider_{current_attempt}.log"
158+
)
107159
logging.basicConfig(
108160
filename=log_file,
109161
level=logging.INFO,
@@ -133,7 +185,7 @@ def run(
133185
io=io,
134186
)
135187
coder.max_reflections = self.max_iteration
136-
coder.stream = True
188+
coder.stream = False
137189

138190
# Run the agent
139191
if test_first:

agent/class_types.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,5 @@ class AgentConfig:
2222
run_tests: bool
2323
max_iteration: int
2424
record_test_for_each_commit: bool
25+
implementation_strategy: str
26+
repeat_times_for_each_inquiry: int

agent/cli.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,14 @@ def config(
135135
False,
136136
help="Run the lint on the entire directory",
137137
),
138+
implementation_strategy: str = typer.Option(
139+
"module_by_module",
140+
help="Implementation strategy to use",
141+
),
142+
repeat_times_for_each_inquiry: int = typer.Option(
143+
1,
144+
help="Repeat times for each inquiry",
145+
),
138146
record_test_for_each_commit: bool = typer.Option(
139147
False,
140148
help="Record the test for each commit",
@@ -173,6 +181,8 @@ def config(
173181
"use_lint_info": use_lint_info,
174182
"max_lint_info_length": max_lint_info_length,
175183
"run_entire_dir_lint": run_entire_dir_lint,
184+
"implementation_strategy": implementation_strategy,
185+
"repeat_times_for_each_inquiry": repeat_times_for_each_inquiry,
176186
"pre_commit_config_path": pre_commit_config_path,
177187
"record_test_for_each_commit": record_test_for_each_commit,
178188
}

0 commit comments

Comments
 (0)