From 0b92625cc9d8c2f029e0a1ff928a09f7a6bc7f55 Mon Sep 17 00:00:00 2001 From: Jacobsolawetz Date: Tue, 20 Aug 2024 17:39:35 -0500 Subject: [PATCH] pass through new generate params (#82) * pass through new generate params * multi turn * update version --- .gitignore | 1 + arcee/__init__.py | 2 +- arcee/api.py | 22 ++++++++++++++++++++-- 3 files changed, 22 insertions(+), 3 deletions(-) diff --git a/.gitignore b/.gitignore index 3726b83..65c5c6d 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,7 @@ __pycache__/ *.py[cod] *$py.class .idea +*.ipynb # C extensions *.so diff --git a/arcee/__init__.py b/arcee/__init__.py index bc40fb1..a9e04de 100644 --- a/arcee/__init__.py +++ b/arcee/__init__.py @@ -1,4 +1,4 @@ -__version__ = "1.3.6" +__version__ = "1.3.7" import os diff --git a/arcee/api.py b/arcee/api.py index 2dc6b70..6207d05 100644 --- a/arcee/api.py +++ b/arcee/api.py @@ -472,8 +472,26 @@ def deployment_status(deployment_name: str) -> Dict[str, str]: return make_request("get", Route.deployment + "/status", data) -def generate(deployment_name: str, query: str) -> Dict[str, str]: - data = {"deployment_name": deployment_name, "query": query} +def generate( + deployment_name: str, + query: str | None = None, + messages: List[Dict[str, str]] | None = None, + repetition_penalty: float | None = None, + top_k: int | None = None, + max_new_tokens: int | None = None, + temperature: float | None = None, + top_p: float | None = None, +) -> Dict[str, str]: + data = { + "deployment_name": deployment_name, + "query": query, + "repetition_penalty": repetition_penalty, + "top_k": top_k, + "max_new_tokens": max_new_tokens, + "temperature": temperature, + "messages": messages, + "top_p": top_p, + } return make_request("post", Route.deployment + "/generate", data)