Skip to content

Commit

Permalink
small cleanup in apis
Browse files Browse the repository at this point in the history
  • Loading branch information
Ben-Epstein committed Sep 22, 2023
1 parent b329865 commit b85bd44
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 6 deletions.
6 changes: 3 additions & 3 deletions arcee/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
__version__ = "0.0.12"
__version__ = "0.0.13"


from arcee.api import get_dalm, train_dalm, upload_doc, upload_docs
from arcee.api import get_dalm, get_dalm_status, train_dalm, upload_doc, upload_docs
from arcee.config import ARCEE_API_KEY, ARCEE_APP_URL
from arcee.dalm import DALM

if not ARCEE_API_KEY:
raise Exception(f"ARCEE_API_KEY must be in the environment. You can retrieve your API key from {ARCEE_APP_URL}")


__all__ = ["upload_docs", "upload_doc", "train_dalm", "get_dalm", "DALM"]
__all__ = ["upload_docs", "upload_doc", "train_dalm", "get_dalm", "DALM", "get_dalm_status"]
14 changes: 11 additions & 3 deletions arcee/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import requests

from arcee.config import ARCEE_API_KEY, ARCEE_API_URL, ARCEE_API_VERSION, ARCEE_APP_URL
from arcee.dalm import DALM
from arcee.dalm import DALM, check_model_status


def upload_doc(context: str, doc_name: str, doc_text: str) -> dict[str, str]:
Expand Down Expand Up @@ -73,8 +73,16 @@ def train_dalm(

if response.status_code != 201:
raise Exception(f"Failed to train model. Response: {response.text}")
else:
print(f"DALM model training started - view model status at {ARCEE_APP_URL}, then arcee.get_model(" + name + ")")
status_url = f"{ARCEE_APP_URL}/arcee/models/{name}/training"
print(
f"DALM model training started - view model status at {status_url} or with `arcee.get_dalm_status({name}).\n"
f"Then, get your DALM with arcee.get_dalm({name})"
)


def get_dalm_status(id_or_name: str) -> dict[str, str]:
"""Gets the status of a DALM training job"""
return check_model_status(id_or_name)


def get_dalm(name: str) -> DALM:
Expand Down

0 comments on commit b85bd44

Please sign in to comment.