diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index c013f2e..269bd2d 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -4,26 +4,33 @@ on: release: types: [created] +permissions: + contents: read + pages: write + id-token: write + +concurrency: + group: "publish" + cancel-in-progress: false + jobs: publish: runs-on: ubuntu-latest - strategy: - matrix: - python-version: ["3.11"] - fail-fast: false + steps: - name: checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: fetch-depth: 0 - - name: apt-get update - run: sudo apt-get update -y + - name: pages + uses: actions/configure-pages@v5 + - name: set up python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: - python-version: ${{ matrix.python-version }} + python-version-file: ".python-version" cache: "pip" cache-dependency-path: "pyproject.toml" @@ -33,8 +40,18 @@ jobs: - name: build run: inv build - - name: publish + - name: publish pypi env: TWINE_USERNAME: __token__ TWINE_PASSWORD: ${{ secrets.TWINE_TOKEN }} - run: inv publish \ No newline at end of file + run: inv publish + + - name: generate docs + run: inv docs + + - uses: actions/upload-pages-artifact@v3 + with: + path: _site/arcee + + - id: publish-docs + uses: actions/deploy-pages@v4 diff --git a/.gitignore b/.gitignore index d4e33b8..e02c256 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,6 @@ +# Generated docs +_site + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] @@ -133,4 +136,4 @@ EleutherAI *.csv *.jsonl -env \ No newline at end of file +env diff --git a/README.md b/README.md index 58d9603..bcbdfe9 100644 --- a/README.md +++ b/README.md @@ -1,16 +1,20 @@ -# Arcee Client Docs +# Arcee Python Client -The Arcee client for executing domain-adpated language model routines +> The Arcee Python client allows you to manage CPT, SFT, DPO, and Merge models on the Arcee Platform. + +This client may be used as a CLI by invoking `arcee` from the terminal, or as an SDK for programmatic use by `import arcee` in Python. + +Learn more at https://docs.arcee.ai ## Installation ``` -pip install arcee-py +pip install --upgrade arcee-py ``` ## Authenticating -Your Arcee API key is obtained at app.arcee.ai +Your Arcee API key is obtained at https://app.arcee.ai In bash: @@ -68,9 +72,9 @@ NOTE: you will need to set `HUGGINGFACE_TOKEN` in your environment to use this f ``` arcee.api.upload_hugging_face_dataset_qa_pairs( - "my_qa_pairs", - hf_dataset_id="org/dataset", - dataset_split="train", + "my_qa_pairs", + hf_dataset_id="org/dataset", + dataset_split="train", data_format="chatml" ) ``` diff --git a/arcee/__init__.py b/arcee/__init__.py index a6d2877..6a7cfa0 100644 --- a/arcee/__init__.py +++ b/arcee/__init__.py @@ -1,4 +1,4 @@ -__version__ = "1.2.2" +__version__ = "1.3.0" import os @@ -28,10 +28,8 @@ if not config.ARCEE_API_KEY: # We check this because it's impossible the user imported arcee, _then_ set the env, then imported again + # We don't want to block or prompt here because this will interfere with CLI usage config.ARCEE_API_KEY = os.getenv("ARCEE_API_KEY", "") - while not config.ARCEE_API_KEY: - config.ARCEE_API_KEY = input("ARCEE_API_KEY not found in environment. Please input api key: ") - os.environ["ARCEE_API_KEY"] = config.ARCEE_API_KEY __all__ = [ "upload_docs", diff --git a/arcee/cli.py b/arcee/cli.py deleted file mode 100644 index 3dd0515..0000000 --- a/arcee/cli.py +++ /dev/null @@ -1,340 +0,0 @@ -from itertools import groupby -from pathlib import Path -from typing import List, Optional - -import typer -from click import ClickException as ArceeException -from rich.console import Console -from rich.table import Table -from typing_extensions import Annotated - -import arcee.api -from arcee import DALM -from arcee.cli_handler import UploadHandler, WeightsDownloadHandler - -console = Console() - -cli = typer.Typer() -"""Arcee CLI""" - - -# FIXME: train_dalm seems to no longer exist... -# @cli.command() -# def train( -# name: Annotated[str, typer.Argument(help="Name of the model")], -# context: Annotated[Optional[str], typer.Option(help="Name of the context")] = None, -# instructions: Annotated[Optional[str], typer.Option(help="Instructions for the model")] = None, -# generator: Annotated[str, typer.Option(help="Generator type")] = "Command", -# ) -> None: -# # name: str, context: Optional[str] = None, instructions: Optional[str] = None, generator: str = "Command" -# """Train a model - -# Args: -# name (str): Name of the model -# context (str): Name of the context -# instructions (str): Instructions for the model -# generator (str): Generator type. Defaults to "Command". -# """ - -# try: -# train_dalm(name, context, instructions, generator) -# typer.secho(f"✅ Model {name} set for training.") -# except Exception as e: -# raise ArceeException( -# message=f"Error training model: {e}", -# ) from e - - -@cli.command() -def generate( - name: Annotated[str, typer.Argument(help="Model name")], - query: Annotated[str, typer.Option(help="Query string")], - size: Annotated[int, typer.Option(help="Size of the response")] = 3, -) -> None: - """Generate from model - - Args: - name (str): Name of the model - query (str): Query string - size (int): Size of the response. Defaults to 3. - - """ - - try: - dalm = DALM(name=name) - resp = dalm.generate(query=query, size=size) - typer.secho(resp) - except Exception as e: - raise ArceeException(message=f"Error generating: {e}") from e - - -@cli.command() -def retrieve( - name: Annotated[str, typer.Argument(help="Model name")], - query: Annotated[str, typer.Option(help="Query string")], - size: Annotated[int, typer.Option(help="Size")] = 3, -) -> None: - """Retrieve from model - - Args: - name (str): Name of the model - query (str): Query string - size (int): Size of the response. Defaults to 3. - - """ - try: - dalm = DALM(name=name) - resp = dalm.retrieve(query=query, size=size) - typer.secho(resp) - except Exception as e: - raise ArceeException(message=f"Error retrieving: {e}") from e - - -upload = typer.Typer(help="Upload data to Arcee platform") - - -@upload.command() -def context( - name: Annotated[str, typer.Argument(help="Name of the context")], - file: Annotated[ - Optional[List[Path]], - typer.Option(help="Path to a document", exists=True, file_okay=True, dir_okay=False, readable=True), - ] = None, - doc_name: Annotated[ - str, - typer.Option(help="Column/key representing the doc name. Used if file is jsonl or csv", exists=True), - ] = "name", - doc_text: Annotated[ - str, - typer.Option(help="Column/key representing the doc text. Used if file is jsonl or csv", exists=True), - ] = "text", - directory: Annotated[ - Optional[List[Path]], - typer.Option( - help="Path to a directory", - exists=True, - file_okay=False, - dir_okay=True, - ), - ] = None, - chunk_size: Annotated[ - int, typer.Option(help="Specify the chunk size in megabytes (MB) to limit memory usage during file uploads.") - ] = 512, -) -> None: - """Upload document(s) to context. If a directory is provided, all valid files in the directory will be uploaded. - At least one of file or directory must be provided. - - If you are using CSV or jsonl file(s), every key/column in your dataset that isn't that of `doc_name` and `doc_text` - will be uploaded as extra metadata fields with your doc. These can be used for filtering on generation and retrieval - - Args: - name (str): Name of the context - file (Path): Path to the file. - directory (Path): Path to the directory. - chunk_size (int): The chunk size in megabytes (MB) to limit memory usage during file uploads. - doc_name (str): The name of the column/key representing the doc name. Used for csv/jsonl - doc_text (str): The name of the column/key representing the doc text/content. Used for csv/jsonl - """ - if not file and not directory: - raise typer.BadParameter("At least one file or directory must be provided") - - if file is None: - file = [] - - if directory is None: - directory = [] - - file.extend(directory) - - try: - resp = UploadHandler.handle_doc_upload(name, file, chunk_size, doc_name, doc_text) - typer.secho(resp) - except Exception as e: - raise ArceeException(message=f"Error uploading document(s): {e}") from e - - -@upload.command(hidden=True) # TODO: - remove hidden=True when vocabulary upload is implemented -def vocabulary( - name: Annotated[str, typer.Argument(help="Name of the context")], - file: Annotated[ - Optional[List[Path]], - typer.Option(help="Path to a document", exists=True, file_okay=True, dir_okay=False, readable=True), - ] = None, - directory: Annotated[ - Optional[List[Path]], - typer.Option( - help="Path to a directory", - exists=True, - file_okay=False, - dir_okay=True, - ), - ] = None, -) -> None: - """Upload a vocabulary file - - Args: - name (str): Name of the vocabulary - file (Path): Path to the file - """ - if not file and not directory: - raise typer.BadParameter("Atleast one file or directory must be provided") - - if directory is None: - directory = [] - - if file is None: - file = [] - - file.extend(directory) - - docs = [] - for f in file: - docs.append({"doc_name": f.name, "doc_text": f.read_text()}) - - # TODO: upload_vocabulary - # Uploadhandler.handle_vocabulary_upload(name, data) - - -cli.add_typer(upload, name="upload") - - -cpt = typer.Typer(help="Manage CPT") - - -@cpt.command(name="list") -def list_cpts() -> None: - """List all CPTs""" - try: - result = arcee.api.list_pretrainings() - - table = Table("Name", "Status", "Base Generator", "Last Updated", title="List CPTs") - - key_func = lambda x: x["processing_state"] # noqa: E731 - grouped_data = {key: list(group) for key, group in groupby(result, key_func)} - - captions = [] - - for key, group in grouped_data.items(): - if key == "failed": - captions.append(typer.style(f"Failed: {len(group)}", fg="red")) - elif key == "completed": - captions.append(typer.style(f"Completed: {len(group)}", fg="green")) - elif key == "processing": - captions.append(typer.style(f"Processing: {len(group)}", fg="yellow")) - elif key == "pending": - captions.append(typer.style(f"Pending: {len(group)}", fg="blue")) - - table.add_section() - - for cpt in list(group): - table.add_row( - cpt["name"], - cpt["status"], - cpt["base_generator"], - cpt.get("updated_at") or cpt.get("created_at", "-"), - ) - - if len(captions) > 0: - table.caption = " | ".join(captions) - console.print(table) - except Exception as e: - raise ArceeException(message=f"Error listing CPTs: {e}") from e - - -@cpt.command(name="download") -def download_cpt_weights( - name: Annotated[ - str, - typer.Option(help="Name of the CPT model to download weights for", prompt="Enter the name of the CPT model"), - ], - out: Annotated[ - Optional[Path], - typer.Option(help="Path to download file to", file_okay=True, dir_okay=False, readable=True), - ] = None, -) -> None: - """Download CPT weights""" - WeightsDownloadHandler.handle_weights_download("pretraining", name, out) - - -cli.add_typer(cpt, name="cpt") - - -sft = typer.Typer(help="Manage SFT") - - -@sft.command(name="download") -def download_sft_weights( - name: Annotated[ - str, - typer.Option(help="Name of the SFT model to download weights for", prompt="Enter the name of the SFT model"), - ], - out: Annotated[ - Optional[Path], - typer.Option(help="Path to download file to", file_okay=True, dir_okay=False, readable=True), - ] = None, -) -> None: - """Download SFT weights""" - WeightsDownloadHandler.handle_weights_download("alignment", name, out) - - -cli.add_typer(sft, name="sft") - - -retriever = typer.Typer(help="Manage Retrievers") - - -@retriever.command(name="download") -def download_retriever_weights( - name: Annotated[ - str, - typer.Option( - help="Name of the retriever model to download weights for", prompt="Enter the name of the retriever model" - ), - ], - out: Annotated[ - Optional[Path], - typer.Option(help="Path to download file to", file_okay=True, dir_okay=False, readable=True), - ] = None, -) -> None: - """Download Retriever weights""" - WeightsDownloadHandler.handle_weights_download("retriever", name, out) - - -cli.add_typer(retriever, name="retriever") - -merging = typer.Typer(help="Manage Merging") - - -@merging.command(name="download") -def download_merging_weights( - name: Annotated[ - str, - typer.Option( - help="Name of the merging model to download weights for", prompt="Enter the name of the merging model" - ), - ], - out: Annotated[ - Optional[Path], - typer.Option(help="Path to download file to", file_okay=True, dir_okay=False, readable=True), - ] = None, -) -> None: - """Download Merging weights""" - WeightsDownloadHandler.handle_weights_download("merging", name, out) - - -cli.add_typer(merging, name="merging") - - -@cli.command() -def org() -> None: - """Prints the current org""" - try: - result = arcee.api.get_current_org() - console.print(f"Current org: {result}") - except Exception as e: - console.print_exception() - raise ArceeException(message=f"Error getting current org: {e}") from e - - -if __name__ == "__main__": - cli() diff --git a/arcee/cli/__init__.py b/arcee/cli/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/arcee/cli/app.py b/arcee/cli/app.py new file mode 100644 index 0000000..c6dfb67 --- /dev/null +++ b/arcee/cli/app.py @@ -0,0 +1,108 @@ +from typing import Optional + +import typer +from rich.console import Console +from rich.prompt import Prompt +from typing_extensions import Annotated + +import arcee.api +from arcee.cli.commands.cpt import cpt +from arcee.cli.commands.merging import merging +from arcee.cli.commands.retriever import retriever +from arcee.cli.commands.sft import sft +from arcee.cli.errors import ArceeException +from arcee.cli.typer import ArceeTyper +from arcee.config import ARCEE_API_KEY, ARCEE_API_URL, ARCEE_ORG, write_configuration_value + +console = Console() + +"""Arcee CLI""" +cli = ArceeTyper( + help=f""" + Welcome to the Arcee CLI! 🚀 + + This CLI provides a convenient way to interact with the Arcee platform. + The Arcee client is also available as a Python package for programmatic access. + {ARCEE_API_URL}/docs + """, + epilog="For more information, see our documentation at https://docs.arcee.ai", +) + +############################ +# Subcommands +############################ +cli.add_typer(cpt, name="cpt") +cli.add_typer(merging, name="merging") +cli.add_typer(retriever, name="retriever") +cli.add_typer(sft, name="sft") + + +############################ +# Top-level CLI Commands +############################ +@cli.command() +def org() -> None: + """Prints the current org""" + try: + result = arcee.api.get_current_org() + console.print(f"Current org: {result}") + except Exception as e: + console.print_exception() + raise ArceeException(message=f"Error getting current org: {e}") from e + + +@cli.command() +def configure( + org: Annotated[ + Optional[str], + typer.Option( + help="Your organization. If not provided, we will use your default organization. " + + "Defaults to the ARCEE_ORG environment variable." + ), + ] = None, + api_key: Annotated[ + Optional[str], typer.Option(help="Your API key. Defaults to the ARCEE_API_KEY environment variable.") + ] = None, + api_url: Annotated[ + Optional[str], + typer.Option(help="The URL of the Arcee API. Defaults to ARCEE_API_URL, or https://app.arcee.ai/api."), + ] = None, +) -> None: + """Write a configuration file for the Arcee SDK and CLI""" + + if ARCEE_ORG: + console.print(f"Current org: {ARCEE_ORG}") + if org: + console.print(f"Setting org to {org}") + write_configuration_value("ARCEE_ORG", org) + + if ARCEE_API_URL: + console.print(f"Current API URL: {ARCEE_API_URL}") + if api_url: + console.print(f"Setting API URL to {api_url}") + write_configuration_value("ARCEE_API_URL", api_url) + + console.print(f"API key: {'in' if ARCEE_API_KEY else 'not in'} config (file or env)") + + if api_key: + console.print("Setting API key") + write_configuration_value("ARCEE_API_KEY", api_key) + + if not ARCEE_API_KEY and not api_key: + resp = Prompt.ask( + # password=True, + prompt=""" +Enter your Arcee API key :lock: +Hit enter to leave it as is. +See https://docs.arcee.ai/getting-arcee-api-key/getting-arcee-api-key for more details. +You can also pass this at runtime with the ARCEE_API_KEY environment variable. +""", + ) + if resp: + console.print("Setting API key") + write_configuration_value("ARCEE_API_KEY", resp) + + +# Enter the CLI +if __name__ == "__main__": + cli() diff --git a/arcee/cli/commands/__init__.py b/arcee/cli/commands/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/arcee/cli/commands/cpt.py b/arcee/cli/commands/cpt.py new file mode 100644 index 0000000..2f2ee20 --- /dev/null +++ b/arcee/cli/commands/cpt.py @@ -0,0 +1,79 @@ +from itertools import groupby +from pathlib import Path +from typing import Optional + +import typer +from rich.console import Console +from rich.table import Table +from typing_extensions import Annotated + +import arcee.api +from arcee.cli.errors import ArceeException +from arcee.cli.handlers.weights import WeightsDownloadHandler +from arcee.cli.typer import ArceeTyper + +console = Console() + +cpt = ArceeTyper( + rich_markup_mode="rich", + no_args_is_help=True, + help=""" + Manage CPT models and weights on the Arcee platform + """, + epilog="For more information on CPT, see https://docs.arcee.ai/pretraining/should-i-pretrain", +) + + +@cpt.command(name="list") +def list_cpts() -> None: + """List all CPTs""" + try: + result = arcee.api.list_pretrainings() + + table = Table("Name", "Status", "Base Generator", "Last Updated", title="List CPTs") + + key_func = lambda x: x["processing_state"] # noqa: E731 + grouped_data = {key: list(group) for key, group in groupby(result, key_func)} + + captions = [] + + for key, group in grouped_data.items(): + if key == "failed": + captions.append(typer.style(f"Failed: {len(group)}", fg="red")) + elif key == "completed": + captions.append(typer.style(f"Completed: {len(group)}", fg="green")) + elif key == "processing": + captions.append(typer.style(f"Processing: {len(group)}", fg="yellow")) + elif key == "pending": + captions.append(typer.style(f"Pending: {len(group)}", fg="blue")) + + table.add_section() + + for cpt in list(group): + table.add_row( + cpt["name"], + cpt["status"], + cpt["base_generator"], + cpt.get("updated_at") or cpt.get("created_at", "-"), + ) + + if len(captions) > 0: + table.caption = " | ".join(captions) + console.print(table) + except Exception as e: + raise ArceeException(message=f"Error listing CPTs: {e}") from e + + +@cpt.command(name="download") +def download_cpt_weights( + name: Annotated[ + str, + typer.Option(help="Name of the CPT model to download weights for", prompt="Enter the name of the CPT model"), + ], + out: Annotated[ + Optional[Path], + typer.Option(help="Path to download file to", file_okay=True, dir_okay=False, readable=True), + ] = None, +) -> None: + """Download CPT weights""" + WeightsDownloadHandler.handle_weights_download("pretraining", name, out) diff --git a/arcee/cli/commands/merging.py b/arcee/cli/commands/merging.py new file mode 100644 index 0000000..0c9a955 --- /dev/null +++ b/arcee/cli/commands/merging.py @@ -0,0 +1,29 @@ +from pathlib import Path +from typing import Optional + +import typer +from typing_extensions import Annotated + +from arcee.cli.handlers.weights import WeightsDownloadHandler +from arcee.cli.typer import ArceeTyper + +merging = ArceeTyper( + help="Manage Merging", epilog="For more information on merging, see https://docs.arcee.ai/merging/should-i-merge" +) + + +@merging.command(name="download") +def download_merging_weights( + name: Annotated[ + str, + typer.Option( + help="Name of the merging model to download weights for", prompt="Enter the name of the merging model" + ), + ], + out: Annotated[ + Optional[Path], + typer.Option(help="Path to download file to", file_okay=True, dir_okay=False, readable=True), + ] = None, +) -> None: + """Download Merging weights""" + WeightsDownloadHandler.handle_weights_download("merging", name, out) diff --git a/arcee/cli/commands/retriever.py b/arcee/cli/commands/retriever.py new file mode 100644 index 0000000..84e17b6 --- /dev/null +++ b/arcee/cli/commands/retriever.py @@ -0,0 +1,89 @@ +from pathlib import Path +from typing import List, Optional + +import typer +from typing_extensions import Annotated + +from arcee.cli.errors import ArceeException +from arcee.cli.handlers.upload import UploadHandler +from arcee.cli.handlers.weights import WeightsDownloadHandler +from arcee.cli.typer import ArceeTyper + +retriever = ArceeTyper(help="Manage Retrievers") + + +@retriever.command(name="download") +def download_retriever_weights( + name: Annotated[ + str, + typer.Option( + help="Name of the retriever model to download weights for", prompt="Enter the name of the retriever model" + ), + ], + out: Annotated[ + Optional[Path], + typer.Option(help="Path to download file to", file_okay=True, dir_okay=False, readable=True), + ] = None, +) -> None: + """Download Retriever weights""" + WeightsDownloadHandler.handle_weights_download("retriever", name, out) + + +@retriever.command(name="upload-context", short_help="Upload document(s) to context") +def upload_context( + name: Annotated[str, typer.Argument(help="Name of the context")], + file: Annotated[ + Optional[List[Path]], + typer.Option(help="Path to a document", exists=True, file_okay=True, dir_okay=False, readable=True), + ] = None, + doc_name: Annotated[ + str, + typer.Option(help="Column/key representing the doc name. Used if file is jsonl or csv", exists=True), + ] = "name", + doc_text: Annotated[ + str, + typer.Option(help="Column/key representing the doc text. Used if file is jsonl or csv", exists=True), + ] = "text", + directory: Annotated[ + Optional[List[Path]], + typer.Option( + help="Path to a directory", + exists=True, + file_okay=False, + dir_okay=True, + ), + ] = None, + chunk_size: Annotated[ + int, typer.Option(help="Specify the chunk size in megabytes (MB) to limit memory usage during file uploads.") + ] = 512, +) -> None: + """Upload document(s) to context. If a directory is provided, all valid files in the directory will be uploaded. + At least one of file or directory must be provided. + + If you are using CSV or jsonl file(s), every key/column in your dataset that isn't that of `doc_name` and `doc_text` + will be uploaded as extra metadata fields with your doc. These can be used for filtering on generation and retrieval + + Args: + name (str): Name of the context + file (Path): Path to the file. + directory (Path): Path to the directory. + chunk_size (int): The chunk size in megabytes (MB) to limit memory usage during file uploads. + doc_name (str): The name of the column/key representing the doc name. Used for csv/jsonl + doc_text (str): The name of the column/key representing the doc text/content. Used for csv/jsonl + """ + if not file and not directory: + raise typer.BadParameter("At least one file or directory must be provided") + + if file is None: + file = [] + + if directory is None: + directory = [] + + file.extend(directory) + + try: + resp = UploadHandler.handle_doc_upload(name, file, chunk_size, doc_name, doc_text) + typer.secho(resp) + except Exception as e: + raise ArceeException(message=f"Error uploading document(s): {e}") from e diff --git a/arcee/cli/commands/sft.py b/arcee/cli/commands/sft.py new file mode 100644 index 0000000..e95c106 --- /dev/null +++ b/arcee/cli/commands/sft.py @@ -0,0 +1,27 @@ +from pathlib import Path +from typing import Optional + +import typer +from typing_extensions import Annotated + +from arcee.cli.handlers.weights import WeightsDownloadHandler +from arcee.cli.typer import ArceeTyper + +sft = ArceeTyper( + help="Manage SFT", epilog="For more information on SFT, see https://docs.arcee.ai/aligning/should-i-align-my-model" +) + + +@sft.command(name="download") +def download_sft_weights( + name: Annotated[ + str, + typer.Option(help="Name of the SFT model to download weights for", prompt="Enter the name of the SFT model"), + ], + out: Annotated[ + Optional[Path], + typer.Option(help="Path to download file to", file_okay=True, dir_okay=False, readable=True), + ] = None, +) -> None: + """Download SFT weights""" + WeightsDownloadHandler.handle_weights_download("alignment", name, out) diff --git a/arcee/cli/errors.py b/arcee/cli/errors.py new file mode 100644 index 0000000..f79c0d3 --- /dev/null +++ b/arcee/cli/errors.py @@ -0,0 +1,5 @@ +from click.exceptions import ClickException + + +class ArceeException(ClickException): + pass diff --git a/arcee/cli_handler.py b/arcee/cli/handlers/upload.py similarity index 74% rename from arcee/cli_handler.py rename to arcee/cli/handlers/upload.py index 9d6b5c5..5dac4e6 100644 --- a/arcee/cli_handler.py +++ b/arcee/cli/handlers/upload.py @@ -1,18 +1,17 @@ from importlib.util import find_spec from pathlib import Path -from typing import Dict, List, Optional +from typing import Dict, List import typer -from click import ClickException as ArceeException from rich.console import Console -from rich.progress import DownloadColumn, Progress, SpinnerColumn, TextColumn, TimeElapsedColumn, TransferSpeedColumn +from rich.progress import Progress, SpinnerColumn, TextColumn from arcee import upload_docs -from arcee.api import download_weights, model_weight_types +from arcee.cli.errors import ArceeException from arcee.schemas.doc import Doc if not find_spec("pandas"): - raise ModuleNotFoundError("Cannot find pandas. Please run `pip install 'arcee-py[cli]'` for cli support") + raise ModuleNotFoundError("Cannot find pandas. Please run `pip install --upgrade 'arcee-py[cli]'` for cli support") import pandas as pd @@ -165,45 +164,3 @@ def handle_doc_upload( ) progress.update(uploading, description=f"✅ Uploaded {len(paths)} document(s) to context {name}") return resp - - -class WeightsDownloadHandler: - """Download weights from Arcee platform""" - - @classmethod - def handle_weights_download(cls, kind: model_weight_types, id_or_name: str, path: Optional[Path] = None) -> None: - """Download weights from Arcee platform - - Args: - kind model_weight_types: Type of model weights. - id_or_name str: Name or ID of the model. - path Path: Path to save the weights. - """ - try: - out = path or Path.cwd() / f"{id_or_name}.tar.gz" - console.print(f"Downloading {kind} model weights for {id_or_name} to {out}") - - with open(out, "wb") as f: - with download_weights(kind, id_or_name) as response: - response.raise_for_status() - size = int(response.headers.get("Content-Length", 0)) - - with Progress( - *Progress.get_default_columns(), - DownloadColumn(), - TimeElapsedColumn(), - TransferSpeedColumn(), - transient=True, - ) as progress: - task = progress.add_task( - f"[blue]Downloading {id_or_name} weights...", total=size if size > 0 else None - ) - - for chunk in response.iter_content(chunk_size=8192): - f.write(chunk) - progress.update(task, advance=len(chunk)) - - console.print(f"Downloaded {out} in {progress.get_time()} seconds") - except Exception as e: - console.print_exception() - raise ArceeException(message=f"Error downloading {kind} weights: {e}") from e diff --git a/arcee/cli/handlers/weights.py b/arcee/cli/handlers/weights.py new file mode 100644 index 0000000..e078558 --- /dev/null +++ b/arcee/cli/handlers/weights.py @@ -0,0 +1,52 @@ +from pathlib import Path +from typing import Optional + +from rich.console import Console +from rich.progress import DownloadColumn, Progress, TimeElapsedColumn, TransferSpeedColumn + +from arcee.api import download_weights, model_weight_types +from arcee.cli.errors import ArceeException + +console = Console() + + +class WeightsDownloadHandler: + """Download weights from Arcee platform""" + + @classmethod + def handle_weights_download(cls, kind: model_weight_types, id_or_name: str, path: Optional[Path] = None) -> None: + """Download weights from Arcee platform + + Args: + kind model_weight_types: Type of model weights. + id_or_name str: Name or ID of the model. + path Path: Path to save the weights. + """ + try: + out = path or Path.cwd() / f"{id_or_name}.tar.gz" + console.print(f"Downloading {kind} model weights for {id_or_name} to {out}") + + with open(out, "wb") as f: + with download_weights(kind, id_or_name) as response: + response.raise_for_status() + size = int(response.headers.get("Content-Length", 0)) + + with Progress( + *Progress.get_default_columns(), + DownloadColumn(), + TimeElapsedColumn(), + TransferSpeedColumn(), + transient=True, + ) as progress: + task = progress.add_task( + f"[blue]Downloading {id_or_name} weights...", total=size if size > 0 else None + ) + + for chunk in response.iter_content(chunk_size=8192): + f.write(chunk) + progress.update(task, advance=len(chunk)) + + console.print(f"Downloaded {out} in {progress.get_time()} seconds") + except Exception as e: + console.print_exception() + raise ArceeException(message=f"Error downloading {kind} weights: {e}") from e diff --git a/arcee/cli/typer.py b/arcee/cli/typer.py new file mode 100644 index 0000000..4287532 --- /dev/null +++ b/arcee/cli/typer.py @@ -0,0 +1,24 @@ +from typing import Any + +import typer + + +class ArceeTyper(typer.Typer): + """ + Custom Typer class for Arcee CLI + + Use this instead of `typer.Typer` to set default values for the CLI + and consistency across all commands. + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + # UX + kwargs.setdefault("no_args_is_help", True) + + # Security + kwargs.setdefault("pretty_exceptions_show_locals", False) + + # Formatting + kwargs.setdefault("rich_markup_mode", "rich") + + super().__init__(*args, **kwargs) diff --git a/arcee/config.py b/arcee/config.py index b3ce732..753261c 100644 --- a/arcee/config.py +++ b/arcee/config.py @@ -1,5 +1,50 @@ import json import os +from pathlib import Path + +from typer import get_app_dir + + +def get_configuration_path() -> Path: + """Gets the configuration file path. + Uses typer/click's get_app_dir to get the configuration file path + and allows an ARCEE_CONFIG_LOCATION override. + """ + env_location = os.getenv("ARCEE_CONFIG_LOCATION") + if env_location: + env_path: Path = Path(env_location) + if not env_path.is_file(): + raise FileNotFoundError(f"Configuration file not found at {env_path}") + return env_path + + app_dir = get_app_dir("arcee") + conf_path: Path = Path(app_dir) / "config.json" + + return conf_path + + +def write_configuration_value(key: str, value: str) -> None: + """Writes a configuration value to the configuration file. + Args: + key (string): The name of the configuration variable. + value (string): The value of the configuration variable. + """ + conf_path = get_configuration_path() + config = {} + + if conf_path.is_file(): + with open(conf_path, "r") as f: + try: + config = json.load(f) + except json.JSONDecodeError: + pass + else: + conf_path.touch() + + config[key] = value + + with open(conf_path, "w") as f: + json.dump(config, f) def get_conditional_configuration_variable(key: str, default: str) -> str: @@ -13,19 +58,7 @@ def get_conditional_configuration_variable(key: str, default: str) -> str: Returns: string: The value of the conditional configuration variable. """ - - os_name = os.name - - if os_name == "nt": - default_path = os.getenv("USERPROFILE", "") + "\\arcee\\config.json" - else: - default_path = os.getenv("HOME", "") + "/.config/arcee/config.json" - - # default configuration location - conf_location = os.getenv( - "ARCEE_CONFIG_LOCATION", - default=default_path, - ) + conf_location = get_configuration_path() if os.path.exists(conf_location): with open(conf_location) as f: @@ -36,7 +69,7 @@ def get_conditional_configuration_variable(key: str, default: str) -> str: return (os.getenv(key) or config.get(key)) or default -ARCEE_API_URL = get_conditional_configuration_variable("ARCEE_API_URL", "https://api.arcee.ai") +ARCEE_API_URL = get_conditional_configuration_variable("ARCEE_API_URL", "https://app.arcee.ai/api") ARCEE_APP_URL = get_conditional_configuration_variable("ARCEE_APP_URL", "https://app.arcee.ai") ARCEE_API_KEY = get_conditional_configuration_variable("ARCEE_API_KEY", "") ARCEE_API_VERSION = get_conditional_configuration_variable("ARCEE_API_VERSION", "v2") diff --git a/pyproject.toml b/pyproject.toml index fe1bfb1..659016a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,8 +1,7 @@ [build-system] -requires = ["hatchling<=1.18.0"] +requires = ["hatchling>=1.24.0, <1.25.0"] build-backend = "hatchling.build" - [project] name = "arcee-py" authors = [{name = "Jacob Solowetz", email = "jacob@arcee.ai"}, {name = "Ben Epstein", email = "ben@arcee.ai"}] @@ -10,19 +9,19 @@ readme = "README.md" requires-python = ">=3.8" license = {text = "MIT"} dynamic = ["version", "description"] -packages = [ - { include = "arcee" } - ] dependencies = [ - "requests>=2.32.2,<3.0", - "typer>=0.12.3,<1.0", - "rich>=13.7.1,<14.0", - "pyyaml>=6.0.1,<7.0", - "pydantic>=2.4.2,<3.0", - "StrEnum>=0.4.15,<1.0", - "datasets>=2.19.2,<3.0", + "requests>=2.32.2, <3.0", + "typer>=0.12.3, <1.0", + "rich>=13.7.1, <14.0", + "pyyaml>=6.0.1, <7.0", + "pydantic>=2.4.2, <3.0", + "StrEnum>=0.4.15, <1.0", + "datasets>=2.19.2, <3.0", ] +[project.scripts] +arcee = "arcee.cli.app:cli" + [tool.hatch.build.targets.wheel] packages = ["arcee"] @@ -32,16 +31,17 @@ path = "arcee/__init__.py" [project.optional-dependencies] dev = [ + "arcee-py[cli]", "black", "invoke", "mypy", + "pdoc3; python_version>='3.8'", "pytest", "pytest-cov", "pytest-env", "ruff", "types-PyYAML", "types-requests", - "pandas", "pandas-stubs" ] cli = [ @@ -58,8 +58,14 @@ line-length = 120 target-version = "py311" respect-gitignore = true line-length = 120 + +[tool.ruff.format] +quote-style = "single" + +[tool.ruff.lint] # Pyflakes, bugbear, pycodestyle, pycodestyle warnings, isort -lint.select=["TID252", "B", "F", "E", "W", "I001"] +select=["TID252", "B", "F", "E", "W", "I001"] +# extend-select = ["C901"] [tool.ruff.lint.isort] case-sensitive = true diff --git a/tasks.py b/tasks.py index 1c30cab..5a28edc 100644 --- a/tasks.py +++ b/tasks.py @@ -146,6 +146,24 @@ def build(ctx: Context) -> None: ) +@task +def docs(ctx: Context) -> None: + """docs + + Generate documentation HTML. + """ + ctx.run( + "pip install --upgrade pdoc3", + pty=True, + echo=True, + ) + ctx.run( + "python -m pdoc arcee -o _site --html --force", + pty=True, + echo=True, + ) + + @task def publish(ctx: Context) -> None: """Deploy to pypi"""