2 changes: 2 additions & 0 deletions README.md
@@ -28,6 +28,8 @@ You can either:
```
OPENAI_API_KEY: 'xxxxxxx'
ANTHROPIC_API_KEY: 'xxxxxxx'
VERTEX_SERVICE_ACCOUNT_PATH: 'xxxxxxx'
VERTEX_REGION: 'xxxxxxx'
```
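
These keys are read at runtime by the `Config` helper in `src/mage_rtl/gen_config.py`. A minimal sketch of how the new Vertex entries are consumed (the `./key.cfg` path and the `mage_rtl` import path are assumptions):

```python
# Sketch only: Config is the small wrapper class defined in gen_config.py.
from mage_rtl.gen_config import Config

cfg = Config("./key.cfg")
print(cfg["VERTEX_REGION"])  # region passed to the AnthropicVertex clients
```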

### To install iverilog {.tabset}
87 changes: 72 additions & 15 deletions src/mage_rtl/gen_config.py
@@ -1,12 +1,15 @@
import os

import config
from google.oauth2 import service_account
from llama_index.core.llms.llm import LLM
from llama_index.llms.anthropic import Anthropic
from llama_index.llms.openai import OpenAI
from llama_index.llms.vertex import Vertex
from pydantic import BaseModel

from .log_utils import get_logger
from .utils import VertexAnthropicWithCredentials

logger = get_logger(__name__)

@@ -34,27 +37,81 @@ def __getitem__(self, index):


def get_llm(**kwargs) -> LLM:
LLM_func = Anthropic
cfg = Config(kwargs["cfg_path"])
api_key_cfg = ""
if kwargs["provider"] == "anthropic":
LLM_func = Anthropic
api_key_cfg = cfg["ANTHROPIC_API_KEY"]
provider: str = kwargs["provider"]
provider = provider.lower()
if provider == "anthropic":
try:
llm: LLM = Anthropic(
model=kwargs["model"],
api_key=cfg["ANTHROPIC_API_KEY"],
max_tokens=kwargs["max_token"],
)

except Exception as e:
raise Exception(f"gen_config: Failed to get {provider} LLM") from e
elif kwargs["provider"] == "openai":
LLM_func = OpenAI
api_key_cfg = cfg["OPENAI_API_KEY"]
# add more providers if needed
try:
llm: LLM = OpenAI(
model=kwargs["model"],
api_key=cfg["OPENAI_API_KEY"],
max_tokens=kwargs["max_token"],
)

except Exception as e:
raise Exception(f"gen_config: Failed to get {provider} LLM") from e
elif kwargs["provider"] == "vertex":
        logger.warning(
            "Support for Vertex Gemini LLMs is still experimental; use with caution"
        )
service_account_path = os.path.expanduser(cfg["VERTEX_SERVICE_ACCOUNT_PATH"])
if not os.path.exists(service_account_path):
raise FileNotFoundError(
f"Google Cloud Service Account file not found: {service_account_path}"
)
try:
credentials = service_account.Credentials.from_service_account_file(
service_account_path
)
llm: LLM = Vertex(
model=kwargs["model"],
project=credentials.project_id,
credentials=credentials,
max_tokens=kwargs["max_token"],
)

except Exception as e:
raise Exception(f"gen_config: Failed to get {provider} LLM") from e
elif kwargs["provider"] == "vertexanthropic":
service_account_path = os.path.expanduser(cfg["VERTEX_SERVICE_ACCOUNT_PATH"])
if not os.path.exists(service_account_path):
raise FileNotFoundError(
f"Google Cloud Service Account file not found: {service_account_path}"
)
try:
credentials = service_account.Credentials.from_service_account_file(
service_account_path,
scopes=["https://www.googleapis.com/auth/cloud-platform"],
)
llm: LLM = VertexAnthropicWithCredentials(
model=kwargs["model"],
project_id=credentials.project_id,
credentials=credentials,
region=cfg["VERTEX_REGION"],
max_tokens=kwargs["max_token"],
)

except Exception as e:
raise Exception(f"gen_config: Failed to get {provider} LLM") from e
else:
raise ValueError(f"gen_config: Invalid provider: {provider}")

try:
llm: LLM = LLM_func(
model=kwargs["model"],
api_key=api_key_cfg,
max_tokens=kwargs["max_token"],
)
_ = llm.complete("Say 'Hi'")

except Exception as e:
raise Exception("gen_config: Failed to get LLM") from e
raise Exception(
f"gen_config: Failed to complete LLM chat for {provider}"
) from e

return llm
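
For context, a hedged sketch of calling the updated `get_llm` with the new `vertexanthropic` provider; the keyword names follow the `kwargs` lookups above, and the values mirror `tests/test_top_agent.py`:

```python
# Illustrative only: cfg_path must point at a key.cfg containing
# VERTEX_SERVICE_ACCOUNT_PATH and VERTEX_REGION (see the README change above).
llm = get_llm(
    provider="vertexanthropic",
    model="claude-3-7-sonnet@20250219",
    max_token=8192,
    cfg_path="./key.cfg",
)
response = llm.complete("Say 'Hi'")  # same smoke test get_llm itself performs
```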

58 changes: 44 additions & 14 deletions src/mage_rtl/token_counter.py
@@ -8,10 +8,13 @@
from llama_index.core.llms.llm import LLM
from llama_index.llms.anthropic import Anthropic
from llama_index.llms.openai import OpenAI
from llama_index.llms.vertex import Vertex
from pydantic import BaseModel
from vertexai.preview.generative_models import GenerativeModel

from .gen_config import get_exp_setting
from .log_utils import get_logger
from .utils import reformat_json_string

logger = get_logger(__name__)

@@ -70,14 +73,23 @@ def __str__(self) -> str:
class TokenCost(BaseModel):
"""Token cost of an LLM call"""

in_token_cost_per_token: float
out_token_cost_per_token: float
in_token_cost_per_token: float = 0.0
out_token_cost_per_token: float = 0.0


token_costs = {
"claude-3-5-sonnet-20241022": TokenCost(
in_token_cost_per_token=3.0 / 1000000, out_token_cost_per_token=15.0 / 1000000
),
"claude-3-5-sonnet@20241022": TokenCost(
in_token_cost_per_token=3.0 / 1000000, out_token_cost_per_token=15.0 / 1000000
),
"claude-3-7-sonnet-20250219": TokenCost(
in_token_cost_per_token=3.0 / 1000000, out_token_cost_per_token=15.0 / 1000000
),
"claude-3-7-sonnet@20250219": TokenCost(
in_token_cost_per_token=3.0 / 1000000, out_token_cost_per_token=15.0 / 1000000
),
"gpt-4o-2024-08-06": TokenCost(
in_token_cost_per_token=2.5 / 1000000, out_token_cost_per_token=10.0 / 1000000
),
@@ -93,6 +105,9 @@ class TokenCost(BaseModel):
"gemini-1.5-pro-002": TokenCost(
in_token_cost_per_token=1.25 / 1000000, out_token_cost_per_token=5.0 / 1000000
),
"gemini-2.0-flash-001": TokenCost(
in_token_cost_per_token=0.1 / 1000000, out_token_cost_per_token=0.4 / 1000000
),
}
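
As a sanity check on the table above, a quick sketch of converting token counts into USD (the counts are made up):

```python
# Hypothetical counts; the per-token prices come straight from token_costs.
cost = token_costs["gemini-2.0-flash-001"]
in_tokens, out_tokens = 1000, 250
usd = in_tokens * cost.in_token_cost_per_token + out_tokens * cost.out_token_cost_per_token
print(f"${usd:.6f}")  # 1000 * 0.1/1e6 + 250 * 0.4/1e6 = $0.000200
```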


@@ -105,16 +120,33 @@ def __init__(self, llm: LLM) -> None:
self.token_cnts_lock = asyncio.Lock()
self.cur_tag = ""
self.max_parallel_requests: int = 10
self.enable_reformat_json = isinstance(llm, Vertex)
model = llm.metadata.model_name
if isinstance(llm, OpenAI):
self.encoding = tiktoken.encoding_for_model(model)
elif isinstance(llm, Anthropic):
self.encoding = llm.tokenizer
elif isinstance(llm, Vertex):
assert llm.model.startswith(
"gemini"
), f"Non-gemini Vertex model is not supported: {llm.model}"
assert isinstance(llm._client, GenerativeModel)

class VertexEncoding:
def __init__(self, client: GenerativeModel):
self.client = client

def encode(self, text: str) -> List[str]:
token_len = self.client.count_tokens(text).total_tokens
return ["placeholder" for _ in range(token_len)]

self.encoding = VertexEncoding(llm._client)
self.activate_structure_output = True
else:
raise Exception(f"gen_config: No tokenizer for model {model}")
logger.info(f"Found tokenizer for model '{model}'")
self.token_cost = token_costs[model] if model in token_costs else None
if self.token_cost is None:
self.token_cost = token_costs[model] if model in token_costs else TokenCost()
if self.token_cost == TokenCost():
logger.warning(
f"Cannot find token cost for model '{model}' in record. Won't display cost in USD"
)
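
The `VertexEncoding` shim above exists because the Vertex SDK only exposes a token count, not token ids, so `encode` fabricates a list whose length equals that count and `len(encoding.encode(text))` keeps working. A sketch of what callers observe (assuming `counter` is a `TokenCounter` wrapping a Vertex gemini LLM):

```python
tokens = counter.encoding.encode("always @(posedge clk) q <= d;")
print(len(tokens))   # real count from client.count_tokens(...).total_tokens
print(tokens[:2])    # ['placeholder', 'placeholder'] -- the ids are not available
```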
@@ -147,6 +179,8 @@ def count_chat(
out_token_cnt = self.count(response.message.content)
token_cnt = TokenCount(in_token_cnt=in_token_cnt, out_token_cnt=out_token_cnt)
self.token_cnts[self.cur_tag].append(token_cnt)
if self.enable_reformat_json:
response.message.content = reformat_json_string(response.message.content)
return (response, token_cnt)

async def count_achat(
@@ -165,6 +199,8 @@ async def count_achat(
token_cnt = TokenCount(in_token_cnt=in_token_cnt, out_token_cnt=out_token_cnt)
async with self.token_cnts_lock:
self.token_cnts[self.cur_tag].append(token_cnt)
if self.enable_reformat_json:
response.message.content = reformat_json_string(response.message.content)
return (response, token_cnt)

async def count_achat_batch(
@@ -284,11 +320,6 @@ def count_chat(
)
response = llm.chat(
messages,
extra_headers=(
{"anthropic-beta": "prompt-caching-2024-07-31"}
if self.enable_cache
else {}
),
top_p=settings.top_p,
temperature=settings.temperature,
)
@@ -309,6 +340,8 @@ def count_chat(
),
)
self.token_cnts[self.cur_tag].append(token_cnt)
if self.enable_reformat_json:
response.message.content = reformat_json_string(response.message.content)
return (response, token_cnt)

async def count_achat(
@@ -321,11 +354,6 @@ async def count_achat(
)
response = await llm.achat(
messages,
extra_headers=(
{"anthropic-beta": "prompt-caching-2024-07-31"}
if self.enable_cache
else {}
),
top_p=settings.top_p,
temperature=settings.temperature,
)
@@ -347,6 +375,8 @@
)
async with self.token_cnts_lock:
self.token_cnts[self.cur_tag].append(token_cnt)
if self.enable_reformat_json:
response.message.content = reformat_json_string(response.message.content)
return (response, token_cnt)

def log_token_stats(self) -> None:
60 changes: 60 additions & 0 deletions src/mage_rtl/utils.py
@@ -1,6 +1,66 @@
import re

import anthropic
from llama_index.llms.anthropic import Anthropic


def add_lineno(file_content: str) -> str:
lines = file_content.split("\n")
ret = ""
for i, line in enumerate(lines):
ret += f"{i+1}: {line}\n"
return ret
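
A quick illustration of `add_lineno` (expected output shown in comments):

```python
print(add_lineno("module top;\nendmodule"), end="")
# 1: module top;
# 2: endmodule
```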


def reformat_json_string(output: str) -> str:
    # Gemini often wraps its output in markdown fences such as ```json ... ```
    # or ```xml ... ```; strip the fences with a regex and return the inner text.
pattern = r"```json(.*?)```"
match = re.search(pattern, output, re.DOTALL)
if match:
return match.group(1).strip()

pattern = r"```xml(.*?)```"
match = re.search(pattern, output, re.DOTALL)
if match:
return match.group(1).strip()

return output.strip()
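
For example, a fenced Gemini reply is reduced to its payload:

```python
raw = '```json\n{"pass": true}\n```'
print(reformat_json_string(raw))  # -> {"pass": true}
```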


class VertexAnthropicWithCredentials(Anthropic):
def __init__(self, credentials, **kwargs):
"""
In addition to all parameters accepted by Anthropic, this class accepts a
new parameter `credentials` that will be passed to the underlying clients.
"""
        # Read the parameters that determine the client type so we can reuse them in the Vertex branch below.
region = kwargs.get("region")
project_id = kwargs.get("project_id")
aws_region = kwargs.get("aws_region")

# Call the parent initializer; this sets up a default _client and _aclient.
super().__init__(**kwargs)

# If using AnthropicVertex (i.e., region and project_id are provided and aws_region is None),
# override the _client and _aclient with the additional credentials parameter.
if region and project_id and not aws_region:
self._client = anthropic.AnthropicVertex(
region=region,
project_id=project_id,
credentials=credentials, # extra argument
timeout=self.timeout,
max_retries=self.max_retries,
default_headers=kwargs.get("default_headers"),
)
self._aclient = anthropic.AsyncAnthropicVertex(
region=region,
project_id=project_id,
credentials=credentials, # extra argument
timeout=self.timeout,
max_retries=self.max_retries,
default_headers=kwargs.get("default_headers"),
)
# Optionally, you could add similar overrides for the aws_region branch if needed.
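
A hedged construction example, mirroring the `vertexanthropic` branch of `get_llm` in `gen_config.py` (the service-account path and region are placeholders):

```python
# Sketch only: any region where the model is actually deployed will do.
from google.oauth2 import service_account

credentials = service_account.Credentials.from_service_account_file(
    "/path/to/service_account.json",  # placeholder
    scopes=["https://www.googleapis.com/auth/cloud-platform"],
)
llm = VertexAnthropicWithCredentials(
    model="claude-3-7-sonnet@20250219",
    project_id=credentials.project_id,
    credentials=credentials,
    region="us-east5",  # placeholder region
    max_tokens=8192,
)
```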
9 changes: 5 additions & 4 deletions tests/test_top_agent.py
@@ -1,6 +1,5 @@
import argparse
import json
import os
import time
from datetime import timedelta
from typing import Any, Dict
@@ -22,8 +21,10 @@


args_dict = {
"provider": "anthropic",
"model": "claude-3-5-sonnet-20241022",
"provider": "vertexanthropic",
"model": "claude-3-7-sonnet@20250219",
# "model": "gemini-2.0-flash-001",
# "model": "claude-3-7-sonnet-20250219",
# "model": "gpt-4o-2024-08-06",
# "filter_instance": "^(Prob070_ece241_2013_q2|Prob151_review2015_fsm)$",
"filter_instance": "^(Prob011_norgate)$",
@@ -36,7 +37,7 @@
"top_p": 0.95,
"max_token": 8192,
"use_golden_tb_in_mage": True,
"key_cfg_path": os.path.join(os.path.dirname(os.path.abspath(__file__)), "key.cfg"),
"key_cfg_path": "./key.cfg",
}

