Skip to content

Commit

Permalink
pass through new generate params
Browse files Browse the repository at this point in the history
  • Loading branch information
Jacobsolawetz committed Aug 13, 2024
1 parent 6fe04d0 commit c20d1d2
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 3 deletions.
2 changes: 1 addition & 1 deletion arcee/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "1.3.5"
__version__ = "1.3.6"

import os

Expand Down
20 changes: 18 additions & 2 deletions arcee/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,8 +472,24 @@ def deployment_status(deployment_name: str) -> Dict[str, str]:
return make_request("get", Route.deployment + "/status", data)


def generate(deployment_name: str, query: str) -> Dict[str, str]:
data = {"deployment_name": deployment_name, "query": query}
def generate(
deployment_name: str,
query: str,
repetition_penalty: float | None = None,
top_k: int | None = None,
max_new_tokens: int | None = None,
temperature: float | None = None,
top_p: float | None = None
) -> Dict[str, str]:
data = {
"deployment_name": deployment_name,
"query": query,
"repetition_penalty": repetition_penalty,
"top_k": top_k,
"max_new_tokens": max_new_tokens,
"temperature": temperature,
"top_p": top_p
}
return make_request("post", Route.deployment + "/generate", data)


Expand Down

0 comments on commit c20d1d2

Please sign in to comment.