Skip to content

Commit

Permalink
Merge branch 'main' into cli_notebook
Browse files Browse the repository at this point in the history
  • Loading branch information
Tyler-Odenthal authored Aug 19, 2024
2 parents 8c8b27a + 5df2d26 commit 912dfd5
Show file tree
Hide file tree
Showing 5 changed files with 392 additions and 3 deletions.
2 changes: 1 addition & 1 deletion arcee/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "1.3.4"
__version__ = "1.3.6"

import os

Expand Down
10 changes: 8 additions & 2 deletions arcee/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,29 +349,34 @@ def corpus_status(corpus: str) -> Dict[str, str]:

def start_alignment(
alignment_name: str,
qa_set: str,
qa_set: Optional[str] = None,
pretrained_model: Optional[str] = None,
merging_model: Optional[str] = None,
alignment_model: Optional[str] = None,
hf_model: Optional[str] = None,
target_compute: Optional[str] = None,
capacity_id: Optional[str] = None,
alignment_type: str = "sft",
full_or_peft: Optional[str] = "full",
) -> Dict[str, str]:
"""
Start the alignment of a model.
Args:
alignment_name (str): The name of the alignment job.
qa_set (str): The name of the QA set to use.
qa_set (Optional[str]): The name of the QA set to use. Required if alignment_type is "sft".
pretrained_model (Optional[str]): The name of the pretrained model to use, if any.
merging_model (Optional[str]): The name of the merging model to use, if any.
alignment_model (Optional[str]): The name of the final alignment model to use, if any.
hf_model (Optional[str]): The name of the Hugging Face model to use, if any.
target_compute (Optional[str]): The name of the compute to use, e.g., "g5.2xlarge" or
"capacity". If omitted, the default compute will be used.
capacity_id (Optional[str]): The name of the capacity block ID to use. If omitted, an
instance will be launched to perform training.
alignment_type (str): The type of alignment to perform. Can be "sft" or "dpo". Defaults to "sft".
"""
if alignment_type == "sft" and qa_set is None:
raise ValueError("qa_set is required when alignment_type is 'sft'")

data = {
"alignment_name": alignment_name,
Expand All @@ -383,6 +388,7 @@ def start_alignment(
"hf_model": hf_model,
"target_compute": target_compute,
"capacity_id": capacity_id,
"alignment_type": alignment_type,
}

# Assuming make_request is a function that handles the request, it's called here
Expand Down
1 change: 1 addition & 0 deletions arcee/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def write_configuration_value(key: str, value: str) -> None:
except json.JSONDecodeError:
pass
else:
conf_path.parent.mkdir(parents=True, exist_ok=True)
conf_path.touch()

config[key] = value
Expand Down
Loading

0 comments on commit 912dfd5

Please sign in to comment.