Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(weave): Remove temporary n completions #3312

Merged
merged 2 commits into from
Jan 3, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 8 additions & 63 deletions weave/trace_server/llm_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,53 +33,20 @@ def lite_llm_completion(
# This allows us to drop params that are not supported by the LLM provider
litellm.drop_params = True

if supports_n_times(inputs.model) or inputs.n == 1:
try:
res = litellm.completion(
**inputs.model_dump(exclude_none=True),
api_key=api_key,
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_region_name=aws_region_name,
)
return tsi.CompletionsCreateRes(response=res.model_dump())
except Exception as e:
error_message = str(e)
error_message = error_message.replace("litellm.", "")
return tsi.CompletionsCreateRes(response={"error": error_message})

# o1 models with n > 1
results = []
try:
# get n results
for i in range(inputs.n or 1):
results.append(
litellm.completion(
**inputs.model_dump(exclude_none=True),
api_key=api_key,
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_region_name=aws_region_name,
)
)
res = litellm.completion(
**inputs.model_dump(exclude_none=True),
api_key=api_key,
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_region_name=aws_region_name,
)
return tsi.CompletionsCreateRes(response=res.model_dump())
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: move to else block; only include the specific failure in the try block

except Exception as e:
error_message = str(e)
error_message = error_message.replace("litellm.", "")
return tsi.CompletionsCreateRes(response={"error": error_message})

final_result = results[0]
for idx, result in enumerate(results):
if idx != 0:
# append choices
final_result.choices.append(result.choices[0])

# sum usage
final_result.usage = sum_dict_leaves(
[result.usage.model_dump() for result in results]
)

return tsi.CompletionsCreateRes(response=final_result.model_dump())


def get_bedrock_credentials(
model_name: str,
Expand Down Expand Up @@ -122,25 +89,3 @@ def get_bedrock_credentials(
)

return aws_access_key_id, aws_secret_access_key, aws_region_name


NO_N_TIMES_MODEL_NAMES = ("o1-mini", "o1-preview", "o1")


# if the model name contains any of these strings, we don't support n > 1
def supports_n_times(model_name: str) -> bool:
return not any(x in model_name for x in NO_N_TIMES_MODEL_NAMES)


# copied from weave/trace/weave_client.py
def sum_dict_leaves(dicts: list[dict]) -> dict:
# dicts is a list of dictionaries, that may or may not
# have nested dictionaries. Sum all the leaves that match
result: dict = {}
for d in dicts:
for k, v in d.items():
if isinstance(v, dict):
result[k] = sum_dict_leaves([result.get(k, {}), v])
elif v is not None:
result[k] = result.get(k, 0) + v
return result
Loading