Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hejia.zhang/dev1 #133

Merged
merged 2 commits into from
Feb 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Orcar/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from .agent import OrcarAgent
from .edit_agent import EditAgent
from .extract_agent import ExtractAgent
from .search_agent import SearchAgent
from .trace_analysis_agent import TraceAnalysisAgent
from .verify_agent_wrapper import VerifyAgentWrapper

# Public API of the package: every name listed here is bound by one of the
# `from .<module> import ...` lines at the top of this file, and this list is
# what `from Orcar import *` exposes.
__all__ = [
"OrcarAgent",
"SearchAgent",
"ExtractAgent",
"TraceAnalysisAgent",
"EditAgent",
"VerifyAgentWrapper",
] # Specify the public interface of the module
82 changes: 32 additions & 50 deletions Orcar/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,19 @@
get_container,
pause_persistent_container,
)
from Orcar.extract_agent import ExtractAgent
from Orcar.log_utils import (
get_logger,
set_log_dir,
switch_log_to_file,
switch_log_to_stdout,
)
from Orcar.search import SearchManager
from Orcar.search_agent import SearchAgent
from Orcar.types import EditInput, ExtractOutput, SearchInput, SearchOutput
from Orcar.trace_analysis_agent import TraceAnalysisAgent
from Orcar.types import EditInput, SearchInput, SearchOutput, TraceAnalysisOutput


# Stages of the agent pipeline. Because this is an IntEnum the members are
# ordered integers; run_agents relies on that when it evaluates
# `self.final_stage <= Stage.<member>` to decide whether to return early.
class Stage(IntEnum):
# NOTE(review): EXTRACT and TRACE_ANALYSIS both carry the value 1. In an Enum
# the second name with a repeated value becomes an alias of the first. This
# looks like the old and new name of the same stage shown side by side;
# confirm that only one of them exists in the real module.
EXTRACT = 1
TRACE_ANALYSIS = 1
SEARCH = 2
EDIT = 3

Expand All @@ -57,7 +56,7 @@ def __init__(
) -> None:
"""
llm: Should be initiated outside and passed to agent construction
final_stage: Which stage will agent end at, currently support ["extract", "search"]
final_stage: Which stage will agent end at, currently support ["trace_analysis", "search"]
"""
super().__init__()
self.logger = get_logger(__name__)
Expand All @@ -71,7 +70,9 @@ def __init__(
ctr_subprocess=docker_ctr_subprocess, ctr_name=ctr_name
)
self.env = BenchmarkEnv(args, self.ctr_bash)
self.extract_agent = ExtractAgent(llm=llm, env=self.env, verbose=False)
self.trace_analysis_agent = TraceAnalysisAgent(
llm=llm, env=self.env, verbose=False
)
self.base_path = self.env.cache_dir
self.redirect_log: bool = False
self.output_to_file: bool = True
Expand Down Expand Up @@ -99,62 +100,43 @@ def set_redirect_log(self, new_value: bool) -> None:
# Setter for the `output_to_file` flag (initialised to True in __init__).
# The trace-analysis step checks this flag before writing its result as a
# JSON file under `self.output_dir`.
def set_output_to_file(self, new_value: bool) -> None:
self.output_to_file = new_value

def run_extract_agent(self) -> ExtractOutput:
"""Run the extract agent."""
response: AgentChatResponse = self.extract_agent.chat(
def run_trace_analysis_agent(self) -> TraceAnalysisOutput:
"""Run the trace analysis agent."""
response: AgentChatResponse = self.trace_analysis_agent.chat(
json.dumps(dict(self.inst))
)
extract_output = ExtractOutput.model_validate_json(response.response)
self.logger.info("Raw Extract output:")
self.logger.info(extract_output)
# extract_output = self.filter_extract_output(extract_output)
# self.logger.info("Filtered extract output:")
# self.logger.info(extract_output)
trace_analysis_output = TraceAnalysisOutput.model_validate_json(
response.response
)
self.logger.info("Raw Trace Analysis output:")
self.logger.info(trace_analysis_output)

if self.output_to_file:
extract_json_obj = json.loads(extract_output.model_dump_json())
trace_analysis_json_obj = json.loads(
trace_analysis_output.model_dump_json()
)
with open(
f"{self.output_dir}/extractor_{self.inst_id}.json", "w"
f"{self.output_dir}/trace_analyzer_{self.inst_id}.json", "w"
) as handle:
json.dump(extract_json_obj, handle, indent=4)
if self.final_stage == Stage.EXTRACT:
json.dump(trace_analysis_json_obj, handle, indent=4)
if self.final_stage == Stage.TRACE_ANALYSIS:
self.output_insts.append(self.inst_id)
self.logger.info(
f"Current container subprocess: {self.env.ctr_bash.ctr_subprocess.pid}"
)

return extract_output

# Keep only the suspicious-code entries whose keyword the SearchManager can
# resolve in the repository at `self.repo_path`.
# The surviving entries are assigned back to `extract_output.suspicious_code`
# and the same `extract_output` object is returned.
# NOTE(review): the only call to this method visible in this view is commented
# out, so it may be dead code; confirm before relying on it.
def filter_extract_output(self, extract_output: ExtractOutput) -> ExtractOutput:
"""Filter the extract output."""
self.logger.info("Filtering extract output with search manager...")
search_manager = SearchManager(self.repo_path)
suspicious_code = extract_output.suspicious_code
# Entries that the search manager was able to find.
ret = []
for i, c in enumerate(suspicious_code):
keyword = c.keyword
# A falsy file_path (presumably an empty string; verify against CodeInfo)
# is passed to the search as None.
file_path = c.file_path if c.file_path else None
ret_c = search_manager.search_callable(keyword, file_path)
# A miss is recognised by the prefix of the returned text, not by an
# exception or a None result.
if not ret_c.startswith("Cannot find the definition of"):
ret.append(c)
self.logger.info(
f"({i+1:02d}/{len(suspicious_code):02d}) Search Manager found CodeInfo {c}: \n{ret_c}"
)
else:
self.logger.info(
f"({i+1:02d}/{len(suspicious_code):02d}) Search Manager could not find CodeInfo {c}: \n{ret_c}"
)
extract_output.suspicious_code = ret
return extract_output
return trace_analysis_output

def run_search_agent(self, extract_output: ExtractOutput) -> SearchOutput:
def run_search_agent(
self, trace_analysis_output: TraceAnalysisOutput
) -> SearchOutput:
"""
Run the search agent.
It depends on the output of the extract agent.
It depends on the output of the trace analysis agent.
"""
search_input = SearchInput(
problem_statement=self.inst["problem_statement"],
extract_output=extract_output,
trace_analysis_output=trace_analysis_output,
)

self.search_agent = SearchAgent(
Expand Down Expand Up @@ -257,16 +239,16 @@ def run_agents(self) -> str:
return ""

try:
extract_output = self.run_extract_agent()
trace_analysis_output = self.run_trace_analysis_agent()
except Exception:
exc_info = sys.exc_info()
traceback.print_exception(*exc_info)
extract_output = ExtractOutput()
if self.final_stage <= Stage.EXTRACT:
return extract_output.model_dump_json(indent=4)
trace_analysis_output = TraceAnalysisOutput()
if self.final_stage <= Stage.TRACE_ANALYSIS:
return trace_analysis_output.model_dump_json(indent=4)

try:
search_output = self.run_search_agent(extract_output)
search_output = self.run_search_agent(trace_analysis_output)
except Exception:
exc_info = sys.exc_info()
traceback.print_exception(*exc_info)
Expand Down
76 changes: 44 additions & 32 deletions Orcar/formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,13 @@
EDIT_OUTPUT,
EDIT_REQUIREMENTS,
EDIT_SYSTEM_HEADER,
EXTRACT_EXAMPLES,
EXTRACT_FIELDS,
EXTRACT_FORMATS,
EXTRACT_PROMPTS,
FMT_CONTROL_PROMPT,
SEARCH_SYSTEM_HEADER,
STEP_EXAMPLE,
TRACE_ANALYSIS_EXAMPLES,
TRACE_ANALYSIS_FIELDS,
TRACE_ANALYSIS_FORMATS,
TRACE_ANALYSIS_PROMPTS,
)
from .types import BaseReasoningStep, SearchQueue, SearchResult

Expand Down Expand Up @@ -477,31 +477,35 @@ def format(
]


class ExtractChatFormatter(BaseAgentChatFormatter):
"""Extractor Agent formatter."""
class TraceAnalysisChatFormatter(BaseAgentChatFormatter):
"""Trace Analysis Agent formatter."""

def format(self, step: TaskStep, task: Task, handler: str) -> List[ChatMessage]:
"""Format chat history into list of ChatMessage."""
sysheader = ChatMessage(
role=MessageRole.SYSTEM, content=EXTRACT_PROMPTS["header"]
role=MessageRole.SYSTEM, content=TRACE_ANALYSIS_PROMPTS["header"]
)
if handler == "slice":
example = EXTRACT_PROMPTS["example"]
example = TRACE_ANALYSIS_PROMPTS["example"]
example_format_args = {
"example_repo_name": EXTRACT_EXAMPLES[handler]["repo_name"],
"example_input_description": EXTRACT_EXAMPLES[handler][
"example_repo_name": TRACE_ANALYSIS_EXAMPLES[handler]["repo_name"],
"example_input_description": TRACE_ANALYSIS_EXAMPLES[handler][
"input_description"
],
"example_output": "".join(
json.dumps(EXTRACT_EXAMPLES[handler]["example_output"], indent=4)
json.dumps(
TRACE_ANALYSIS_EXAMPLES[handler]["example_output"], indent=4
)
),
}
fmt_example = example.format(**example_format_args)
user_msg = EXTRACT_PROMPTS[handler]
output_format = "".join(json.dumps(EXTRACT_FORMATS[handler], indent=4))
user_msg = TRACE_ANALYSIS_PROMPTS[handler]
output_format = "".join(
json.dumps(TRACE_ANALYSIS_FORMATS[handler], indent=4)
)
format_args = {
"output_format": output_format,
"output_fields": EXTRACT_FIELDS[handler],
"output_fields": TRACE_ANALYSIS_FIELDS[handler],
"example": fmt_example,
"repo_name": task.extra_state["inst"]["repo"],
"input_description": replace_unicode_quotations(
Expand All @@ -521,25 +525,29 @@ def format(self, step: TaskStep, task: Task, handler: str) -> List[ChatMessage]:
elif handler == "parse":
step_name = step.step_state["name"]
parse_type = task.extra_state["parse_type"][step_name]
example = EXTRACT_PROMPTS["example"]
example = TRACE_ANALYSIS_PROMPTS["example"]
example_format_args = {
"example_repo_name": EXTRACT_EXAMPLES[handler][parse_type]["repo_name"],
"example_input_description": EXTRACT_EXAMPLES[handler][parse_type][
"input_description"
"example_repo_name": TRACE_ANALYSIS_EXAMPLES[handler][parse_type][
"repo_name"
],
"example_input_description": TRACE_ANALYSIS_EXAMPLES[handler][
parse_type
]["input_description"],
"example_output": "".join(
json.dumps(
EXTRACT_EXAMPLES[handler][parse_type]["example_output"],
TRACE_ANALYSIS_EXAMPLES[handler][parse_type]["example_output"],
indent=4,
)
),
}
fmt_example = example.format(**example_format_args)
user_msg = EXTRACT_PROMPTS[handler]
output_format = "".join(json.dumps(EXTRACT_FORMATS[handler], indent=4))
user_msg = TRACE_ANALYSIS_PROMPTS[handler]
output_format = "".join(
json.dumps(TRACE_ANALYSIS_FORMATS[handler], indent=4)
)
format_args = {
"output_format": output_format,
"output_fields": EXTRACT_FIELDS[handler],
"output_fields": TRACE_ANALYSIS_FIELDS[handler],
"example": fmt_example,
"repo_name": task.extra_state["inst"]["repo"],
"input_description": task.extra_state["slices"][step_name],
Expand All @@ -555,11 +563,13 @@ def format(self, step: TaskStep, task: Task, handler: str) -> List[ChatMessage]:
fmt_control_msg,
]
elif handler == "judge":
user_msg = EXTRACT_PROMPTS[handler]
output_format = "".join(json.dumps(EXTRACT_FORMATS[handler], indent=4))
user_msg = TRACE_ANALYSIS_PROMPTS[handler]
output_format = "".join(
json.dumps(TRACE_ANALYSIS_FORMATS[handler], indent=4)
)
format_args = {
"output_format": output_format,
"output_fields": EXTRACT_FIELDS[handler],
"output_fields": TRACE_ANALYSIS_FIELDS[handler],
"repo_name": task.extra_state["inst"]["repo"],
"input_description": task.extra_state["inst"]["problem_statement"],
"reproducer_snippet": task.extra_state["slices"][
Expand All @@ -578,25 +588,27 @@ def format(self, step: TaskStep, task: Task, handler: str) -> List[ChatMessage]:
fmt_control_msg,
]
elif handler == "summarize":
user_msg = EXTRACT_PROMPTS[handler]
example = EXTRACT_PROMPTS["example"]
user_msg = TRACE_ANALYSIS_PROMPTS[handler]
example = TRACE_ANALYSIS_PROMPTS["example"]
example_format_args = {
"example_repo_name": EXTRACT_EXAMPLES[handler]["repo_name"],
"example_input_description": EXTRACT_EXAMPLES[handler][
"example_repo_name": TRACE_ANALYSIS_EXAMPLES[handler]["repo_name"],
"example_input_description": TRACE_ANALYSIS_EXAMPLES[handler][
"input_description"
],
"example_output": "".join(
json.dumps(
EXTRACT_EXAMPLES[handler]["example_output"],
TRACE_ANALYSIS_EXAMPLES[handler]["example_output"],
indent=4,
)
),
}
fmt_example = example.format(**example_format_args)
output_format = "".join(json.dumps(EXTRACT_FORMATS[handler], indent=4))
output_format = "".join(
json.dumps(TRACE_ANALYSIS_FORMATS[handler], indent=4)
)
format_args = {
"output_format": output_format,
"output_fields": EXTRACT_FIELDS[handler],
"output_fields": TRACE_ANALYSIS_FIELDS[handler],
"example": fmt_example,
"repo_name": task.extra_state["inst"]["repo"],
"input_description": task.extra_state["inst"]["problem_statement"],
Expand Down
20 changes: 10 additions & 10 deletions Orcar/output_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@
BugLocations,
CodeInfo,
EditOutput,
ExtractJudgeStep,
ExtractParseStep,
ExtractSliceStep,
ExtractSummarizeStep,
SearchActionStep,
TraceAnalysisJudgeStep,
TraceAnalysisParseStep,
TraceAnalysisSliceStep,
TraceAnalysisSummarizeStep,
)

logger = get_logger(__name__)
Expand Down Expand Up @@ -230,8 +230,8 @@ def parse(self, output: str) -> EditOutput:
raise ValueError(f"Could not parse edit output: {output}")


class ExtractOutputParser(BaseOutputParser):
"""Extractor Agent output parser."""
class TraceAnalysisOutputParser(BaseOutputParser):
"""Trace Analysis Agent output parser."""

def parse(self, output: str, method: str) -> BaseReasoningStep:
try:
Expand All @@ -244,7 +244,7 @@ def parse(self, output: str, method: str) -> BaseReasoningStep:
logger.info(err_msg_io.getvalue())
json_obj: Dict = json.loads(output.replace("\\", r"\\"), strict=False)
if method == "slice":
return ExtractSliceStep(
return TraceAnalysisSliceStep(
traceback_warning_log_slice=json_obj["traceback_warning_log_slice"],
issue_reproducer_slice=json_obj["issue_reproducer_slice"],
source_code_slice=json_obj["source_code_slice"],
Expand All @@ -254,9 +254,9 @@ def parse(self, output: str, method: str) -> BaseReasoningStep:
CodeInfo(keyword=x["keyword"], file_path=x["file_path"])
for x in json_obj["code_info_list"]
]
return ExtractParseStep(code_info_list=code_info_list)
return TraceAnalysisParseStep(code_info_list=code_info_list)
elif method == "judge":
return ExtractJudgeStep(
return TraceAnalysisJudgeStep(
is_successful=json_obj["is_successful"],
fixed_reproduce_snippet=json_obj["fixed_reproduce_snippet"],
)
Expand All @@ -265,7 +265,7 @@ def parse(self, output: str, method: str) -> BaseReasoningStep:
CodeInfo(keyword=x["keyword"], file_path=x["file_path"])
for x in json_obj["code_info_list"]
]
return ExtractSummarizeStep(
return TraceAnalysisSummarizeStep(
summary=json_obj["summary"], code_info_list=code_info_list
)
raise NotImplementedError
Loading