Skip to content

Commit

Permalink
add functions to evaluate feedstock sustainability
Browse files Browse the repository at this point in the history
  • Loading branch information
du-phan committed Jan 25, 2025
1 parent 65e07d8 commit 0297b05
Show file tree
Hide file tree
Showing 8 changed files with 510 additions and 29 deletions.
50 changes: 25 additions & 25 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

32 changes: 32 additions & 0 deletions sequestrae_engine/cli/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,35 @@ def extract_audit_information_command(api_key, project_dir, limit=100):

print(f"Successfully processed {markdown_count} markdown files")
return 0


def evaluate_feedstock_sustainability_command(api_key, project_dir, limit=100):
    """Evaluate feedstock sustainability for markdown audit reports in a project dir.

    Scans each non-hidden sub-folder of ``project_dir`` for ``*.md`` files whose
    stem contains "report" and runs the feedstock-sustainability analysis on each.

    Args:
        api_key: Mistral API key handed to ``AuditReportExtractor``. Required.
        project_dir: Path to the directory whose sub-folders hold the reports.
        limit: Maximum number of markdown files to process. ``None`` (e.g. from
            an omitted ``--limit`` CLI flag) is treated as the default of 100.

    Returns:
        int: 0 on success, 1 on a usage error (missing API key or directory).
    """
    if limit is None:
        limit = 100
    print(f"Max number of files to process: {limit}")

    if not api_key:
        print("Error: MISTRAL_API_KEY is required")
        return 1

    extractor = AuditReportExtractor(api_key=api_key)
    project_path = Path(project_dir)

    if not project_path.exists():
        print(f"Error: Directory not found at {project_path}")
        return 1

    markdown_count = 0
    for folder_path in project_path.iterdir():
        # BUG FIX: the limit was previously only checked once per *folder*, so a
        # single folder containing many report files could exceed it. Check it
        # before each folder and before each file.
        if markdown_count >= limit:
            break
        if folder_path.name.startswith(".") or not folder_path.is_dir():
            continue
        for md_path in folder_path.glob("*.md"):
            if markdown_count >= limit:
                break
            if md_path.is_file() and "report" in md_path.stem.lower():
                markdown_count += 1
                try:
                    extractor.analyze_feedstock_sustainability(audit_path=md_path)
                except Exception as e:
                    # Best-effort: report the failure and continue with the
                    # remaining files rather than aborting the whole run.
                    print(f"Error processing {md_path}: {str(e)}")
                time.sleep(1)  # Sleep for 1 second to avoid rate limiting

    print(f"Successfully processed {markdown_count} markdown files")
    return 0
19 changes: 19 additions & 0 deletions sequestrae_engine/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,25 @@ def main():
)
)

# Evaluate feedstock sustainability command
evaluate_feedstock_parser = subparsers.add_parser(
"evaluate-feedstock", help="Evaluate feedstock sustainability from markdown report files"
)
evaluate_feedstock_parser.add_argument(
"--mistral-api-key", help="Mistral API key", required=True
)
evaluate_feedstock_parser.add_argument(
"--project-dir", help="Project data directory", required=True
)
evaluate_feedstock_parser.add_argument(
"--limit", help="Maximum number of files to process", type=int
)
evaluate_feedstock_parser.set_defaults(
func=lambda args: commands.evaluate_feedstock_sustainability_command(
args.mistral_api_key, args.project_dir, args.limit
)
)

args = parser.parse_args()
if hasattr(args, "func"):
exit_status = args.func(args)
Expand Down
100 changes: 96 additions & 4 deletions sequestrae_engine/document_parsing/extractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,25 @@
import logging
import os
import re
import time
from typing import Dict, List

from mistralai import Mistral

from sequestrae_engine.core.utilities import load_json_file

# Configure logging
# NOTE(review): calling basicConfig at import time in a library module can
# override or conflict with the host application's logging configuration —
# consider moving this call to the CLI entry point; confirm with maintainers.
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

# Define paths relative to this file
# (prompt templates and criteria live in a "prompts/" folder next to this module)
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
SYSTEM_PROMPT_PATH = os.path.join(SCRIPT_DIR, "prompts/system_prompt.txt")
FEEDSTOCK_PROMPT_PATH = os.path.join(SCRIPT_DIR, "prompts/feedstock_prompt.txt")
FEEDSTOCK_CRITERIA_PATH = os.path.join(SCRIPT_DIR, "prompts/feedstock_criteria.json")


def read_markdown_file(filepath: str) -> str:
"""
Expand Down Expand Up @@ -65,10 +76,7 @@ def parse_audit_report(self, audit_report_path, output_folder_path=None, overwri
return audit_report_dict

def extract_audit_report_data(self, report_content):
# load the system prompt file
script_dir = os.path.dirname(os.path.abspath(__file__))
prompt_path = os.path.join(script_dir, "system_prompt.txt")
with open(prompt_path, "r") as file:
with open(SYSTEM_PROMPT_PATH, "r") as file:
system_prompt = file.read()

user_message_template = """
Expand Down Expand Up @@ -109,3 +117,87 @@ def extract_audit_report_data(self, report_content):
audit_dict = json.loads(audit_info)

return audit_dict

def analyze_feedstock_sustainability(
self, audit_path: str, output_folder_path=None, overwrite=False
) -> List[Dict]:
"""
Analyze feedstock sustainability from audit report using Mistral LLM.
Args:
audit_path: Path to the audit report file
output_folder_path: Optional path to output folder. If None, uses same directory as input
overwrite: Whether to overwrite existing output file
Returns:
List of analysis results for each topic
"""
# Load prompts and criteria
with open(FEEDSTOCK_PROMPT_PATH, "r") as f:
context_content = f.read()
with open(audit_path, "r") as file:
audit_report = file.read()

criteria_guideline = load_json_file(FEEDSTOCK_CRITERIA_PATH)

# Template for the full message
full_message_template = """
{context_content}
**Criteria Guideline**
Topic: {topic}
{questions}
**Audit report**
{audit_report}
"""

result_list = []

# Process each topic in the criteria guideline
for topic in criteria_guideline.keys():
start_time = time.time()
questions = criteria_guideline.get(topic)

full_message = full_message_template.format(
context_content=context_content,
topic=topic,
questions=questions,
audit_report=audit_report,
)

messages = [{"role": "user", "content": full_message}]

chat_response = self.mistral_client.chat.complete(
model=self.model, messages=messages, response_format={"type": "json_object"}
)

response_content = chat_response.choices[0].message.content
response_content_dict = json.loads(response_content)
result_list.append(response_content_dict)

logger.info(f'Processed topic "{topic}" in {time.time() - start_time:.2f}s')
time.sleep(1) # Rate limiting

# Determine output path
input_dir = os.path.dirname(audit_path)
input_filename = os.path.basename(audit_path)
output_filename = os.path.splitext(input_filename)[0] + "_feedstock_analysis.json"

output_path = os.path.join(
output_folder_path if output_folder_path else input_dir, output_filename
)

# Check if file exists and overwrite is False
if os.path.exists(output_path) and not overwrite:
logger.info(f"Output file already exists at {output_path} and overwrite=False")
return result_list

# Save the JSON file
with open(output_path, "w", encoding="utf-8") as f:
json.dump(result_list, f, indent=2)

logger.info(f"Feedstock analysis complete. Results saved to {output_path}")
return result_list
Loading

0 comments on commit 0297b05

Please sign in to comment.