9
9
from import_deps import ModuleSet
10
10
from graphlib import TopologicalSorter , CycleError
11
11
import yaml
12
-
12
+ from rank_bm25 import BM25Okapi
13
13
from agent .class_types import AgentConfig
14
+ import subprocess
14
15
15
16
PROMPT_HEADER = ">>> Here is the Task:\n "
17
+ FUNCTION_HEADER = "\n \n >>> Here are all functions in the file, complete the implementations for all functions (i.e., those with pass statements):\n "
16
18
REFERENCE_HEADER = "\n \n >>> Here is the Reference for you to finish the task:\n "
17
19
REPO_INFO_HEADER = "\n \n >>> Here is the Repository Information:\n "
18
20
UNIT_TESTS_INFO_HEADER = "\n \n >>> Here are the Unit Tests Information:\n "
19
21
LINT_INFO_HEADER = "\n \n >>> Here is the Lint Information:\n "
20
22
SPEC_INFO_HEADER = "\n \n >>> Here is the Specification Information:\n "
21
23
IMPORT_DEPENDENCIES_HEADER = "\n \n >>> Here are the Import Dependencies:\n "
24
# Per-function instruction template; `{unimplemented_functions}` is filled in
# with str.format once for each stub that still has to be implemented.
# The literal must open with exactly three quotes: a fourth one would become a
# stray, never-closed `"` at the very start of the prompt.
FUNCTION_BY_FUNCTION_HEADER = """\n Your task is to implement function {unimplemented_functions} by replacing the pass statement with actual functional code.
Please note that there could be multiple occurrences of {unimplemented_functions}, and you need to implement them all.
Do not change the names of existing functions or classes, as they may be referenced from other code like unit tests, etc.
When you generate code, you must maintain the original formatting of the function stubs (such as whitespaces), otherwise we will not be able to search/replace blocks for code modifications, and therefore you will receive a score of 0 for your generated code."""
22
28
# prefix components:
23
29
space = " "
24
30
branch = "│ "
@@ -123,6 +129,32 @@ def get_file_info(file_path: Path, prefix: str = "") -> str:
123
129
return "\n " .join (filter (None , tree_string ))
124
130
125
131
132
def get_unimplemented_functions(file_path: Path) -> list[str]:
    """Return a stub marker for every function in a file that is still unimplemented.

    The file is scanned with a regular expression (it is not parsed), so the
    detection is heuristic: a function's "body" is the text after its
    signature and optional docstring, up to the next ``def`` or the end of
    the file.  A function is reported when that body contains a line holding
    only a ``pass`` statement (optionally followed by a comment).

    Args:
        file_path: Path of the Python source file to scan.

    Returns:
        One string of the form ``"def <name> ()"`` per unimplemented function,
        in the order the functions appear in the file.

    """
    with open(file_path, "r") as f:
        content = f.read()

    # Signature, optional docstring, then the body up to the next `def`/EOF.
    pattern = r"def\s+(\w+)\s*\([^)]*\)[^:]*:(?:\s*(?:'''[\s\S]*?'''|\"\"\"[\s\S]*?\"\"\"))?\s*((?:(?!\ndef\s+).)*?)(?=\s*def\s+|\s*$)"
    # A real `pass` statement sits alone on its line.  A plain substring test
    # would also flag implemented code that merely mentions e.g. `password`
    # or `bypass`.
    pass_statement = re.compile(r"^[ \t]*pass[ \t]*(?:#.*)?$", re.MULTILINE)

    unimplemented_functions = []
    for match in re.finditer(pattern, content, re.DOTALL):
        func_name = match.group(1)
        func_body = match.group(2).strip()
        if pass_statement.search(func_body):
            unimplemented_functions.append(f"def {func_name} ()")
    return unimplemented_functions
156
+
157
+
126
158
def collect_test_files (directory : str ) -> list [str ]:
127
159
"""Collect all the test files in the directory."""
128
160
test_files = []
@@ -265,9 +297,9 @@ def get_target_edit_files(
265
297
raise ValueError (
266
298
"topological_sort_files should not be longer than filtered_files"
267
299
)
268
- assert len (topological_sort_files ) == len (
269
- filtered_files
270
- ), "all files should be included"
300
+ assert len (topological_sort_files ) == len (filtered_files ), (
301
+ "all files should be included"
302
+ )
271
303
272
304
# change to latest commit
273
305
local_repo .git .checkout (branch )
@@ -324,9 +356,9 @@ def get_target_edit_files_from_patch(
324
356
raise ValueError (
325
357
"topological_sort_files should not be longer than target_files_list"
326
358
)
327
- assert len (topological_sort_files ) == len (
328
- target_files_list
329
- ), "all files should be included"
359
+ assert len (topological_sort_files ) == len (target_files_list ), (
360
+ "all files should be included"
361
+ )
330
362
331
363
topological_sort_files = [
332
364
file .replace (working_dir , "" ).lstrip ("/" ) for file in topological_sort_files
@@ -347,6 +379,7 @@ def get_message(
347
379
agent_config : AgentConfig ,
348
380
repo_path : str ,
349
381
test_files : list [str ] | None = None ,
382
+ input_file : str | None = None ,
350
383
) -> str :
351
384
"""Get the message to Aider."""
352
385
prompt = f"{ PROMPT_HEADER } " + agent_config .user_prompt
@@ -383,11 +416,11 @@ def get_message(
383
416
with bz2 .open ("spec.pdf.bz2" , "rb" ) as in_file :
384
417
with open ("spec.pdf" , "wb" ) as out_file :
385
418
out_file .write (in_file .read ())
386
- spec_info = (
387
- f" \n { SPEC_INFO_HEADER } "
388
- + get_specification ( specification_pdf_path = Path ( repo_path , "spec.pdf" ))[
389
- : agent_config . max_spec_info_length
390
- ]
419
+ spec_info = f" \n { SPEC_INFO_HEADER } " + get_specification (
420
+ specification_pdf_path = Path ( repo_path , "spec.pdf" ),
421
+ use_retrieval = True ,
422
+ query = input_file if input_file else "" ,
423
+ top_k = 10 ,
391
424
)
392
425
else :
393
426
spec_info = ""
@@ -397,6 +430,39 @@ def get_message(
397
430
return message_to_agent
398
431
399
432
433
def get_message_function_by_function(
    agent_config: AgentConfig,
    repo_path: str,
    input_file: str,
    test_files: list[str] | None = None,
) -> list[str]:
    """Build one agent message per unimplemented function of ``input_file``.

    Every message is the shared context produced by ``get_message`` followed
    by ``FUNCTION_BY_FUNCTION_HEADER`` formatted for a single stub.

    Args:
        agent_config: Agent configuration; ``implementation_strategy`` selects
            the mode.
        repo_path: Root directory of the repository being edited.
        input_file: File to implement, joined onto ``repo_path`` to read it.
        test_files: Optional test files forwarded to ``get_message``.

    Returns:
        The list of messages, one per unimplemented function.  For the
        ``"module_by_module"`` strategy no per-function messages exist, so the
        list is empty.

    Raises:
        ValueError: If ``implementation_strategy`` is not a known strategy.

    """
    # Forward input_file: get_message uses it as the retrieval query when it
    # embeds the specification.
    context = get_message(agent_config, repo_path, test_files, input_file=input_file)

    if agent_config.implementation_strategy == "module_by_module":
        # NOTE(review): nothing is generated per function in this mode; confirm
        # that callers expect an empty list rather than [context].
        function_info = []
    elif agent_config.implementation_strategy == "function_by_function":
        unimplemented_functions = get_unimplemented_functions(
            file_path=Path(os.path.join(repo_path, input_file))
        )
        function_info = [
            FUNCTION_BY_FUNCTION_HEADER.format(unimplemented_functions=stub)
            for stub in unimplemented_functions
        ]
    else:
        raise ValueError(
            f"Invalid implementation strategy: {agent_config.implementation_strategy} "
        )

    # Append the formatted instruction, not the bare stub name, to the context.
    return [context + info for info in function_info]
464
+
465
+
400
466
def update_message_with_dependencies (message : str , dependencies : list [str ]) -> str :
401
467
"""Update the message with the dependencies."""
402
468
if len (dependencies ) == 0 :
@@ -411,19 +477,43 @@ def update_message_with_dependencies(message: str, dependencies: list[str]) -> s
411
477
return message
412
478
413
479
414
def get_specification(
    specification_pdf_path: Path,
    use_retrieval: bool = True,
    query: str = "",
    top_k: int = 20,
) -> str:
    """Get the reference text for a given specification PDF path.

    The text of every page is cut into chunks of 1000 characters.  Without
    retrieval all chunks are returned; with retrieval the chunks are ranked
    with BM25 against the contents of the file named by ``query`` and only the
    ``top_k`` best-scoring chunks are returned, best first.

    Args:
        specification_pdf_path: Path of the specification PDF.
        use_retrieval: Whether to rank the chunks and keep only the best ones.
        query: Path of a text file whose contents are used as the BM25 query.
            Required when ``use_retrieval`` is true.
        top_k: Number of chunks to keep when retrieval is used.

    Returns:
        The selected chunks joined into one string (empty if the PDF yields
        no text).

    Raises:
        ValueError: If retrieval is requested with an empty ``query``.

    """
    # TODO: after pdf_to_text is available, use it to extract the text from the PDF
    # Validate before the PDF work; a real exception survives `python -O`.
    if use_retrieval and query == "":
        raise ValueError("query should not be empty")

    chunk_size = 1000
    corpus = []
    # The context manager closes the document once all pages are read.
    with fitz.open(specification_pdf_path) as document:
        for page_num in range(len(document)):
            page = document.load_page(page_num)  # loads the specified page
            current_page_text = page.get_text()  # type: ignore
            corpus.extend(
                current_page_text[i : i + chunk_size]
                for i in range(0, len(current_page_text), chunk_size)
            )

    if not use_retrieval:
        return "\n ".join(corpus)

    with open(query) as query_file:
        query_text = query_file.read()

    # BM25Okapi cannot be built from an empty corpus (it divides by its size).
    if not corpus:
        return ""

    tokenized_corpus = [doc.split(" ") for doc in corpus]
    bm25 = BM25Okapi(tokenized_corpus)
    # get_scores expects a list of tokens; a raw string would be scored one
    # character at a time.  Tokenize the query the same way as the corpus.
    doc_scores = bm25.get_scores(query_text.split(" "))
    ranked_indices = sorted(
        range(len(corpus)), key=lambda i: doc_scores[i], reverse=True
    )
    return "\n ".join(corpus[i] for i in ranked_indices[:top_k])
427
517
428
518
429
519
def create_branch (repo : git .Repo , branch : str , from_commit : str ) -> None :
@@ -486,6 +576,21 @@ def get_changed_files_from_commits(
486
576
return []
487
577
488
578
579
def run_eval_after_each_commit(
    branch: str, backend: str, commit0_config_file: str, repo_name: str
) -> str:
    """Run the eval command after each commit.

    Invokes ``python -m commit0 evaluate`` for the given branch and captures
    its output.

    Args:
        branch: Branch to evaluate.
        backend: Backend name passed to ``--backend``.
        commit0_config_file: Path passed to ``--commit0-config-file``.
        repo_name: Not used by this function; kept for the callers' signature.

    Returns:
        The command's stdout.  If the command fails, its captured stdout when
        there is any, otherwise the string form of the error.

    """
    # An argument list (no shell) passes each value through verbatim, so
    # spaces or shell metacharacters in a branch name or path cannot break or
    # alter the command.
    eval_cmd = [
        "python",
        "-m",
        "commit0",
        "evaluate",
        "--branch",
        branch,
        "--backend",
        backend,
        "--commit0-config-file",
        commit0_config_file,
        "--timeout",
        "100",
    ]
    try:
        result = subprocess.run(eval_cmd, capture_output=True, text=True, check=True)
    except subprocess.CalledProcessError as e:
        print(f"Error running eval command: {e} ")
        return e.stdout if e.stdout else str(e)
    except OSError as e:
        # Without a shell a missing interpreter raises here instead of
        # producing a non-zero exit status; report it the same way.
        print(f"Error running eval command: {e} ")
        return str(e)
    return result.stdout
592
+
593
+
489
594
def args2string (agent_config : AgentConfig ) -> str :
490
595
"""Converts specific fields from an `AgentConfig` object into a formatted string.
491
596
0 commit comments