From 0574c60d3428b863ebca1b7a2184de58fda0a081 Mon Sep 17 00:00:00 2001 From: Jacobsolawetz Date: Thu, 1 Aug 2024 18:09:22 -0500 Subject: [PATCH] Wire DPO (#67) * route dpos through alignment routine * dpo version bump --- arcee/__init__.py | 2 +- arcee/api.py | 12 +++++++++--- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/arcee/__init__.py b/arcee/__init__.py index bab8e2a..058278a 100644 --- a/arcee/__init__.py +++ b/arcee/__init__.py @@ -1,4 +1,4 @@ -__version__ = "1.3.4" +__version__ = "1.3.5" import os diff --git a/arcee/api.py b/arcee/api.py index e0d1a41..5ae6846 100644 --- a/arcee/api.py +++ b/arcee/api.py @@ -349,29 +349,34 @@ def corpus_status(corpus: str) -> Dict[str, str]: def start_alignment( alignment_name: str, - qa_set: str, + qa_set: Optional[str] = None, pretrained_model: Optional[str] = None, merging_model: Optional[str] = None, alignment_model: Optional[str] = None, hf_model: Optional[str] = None, target_compute: Optional[str] = None, capacity_id: Optional[str] = None, - full_or_peft: Optional[str] = "full", + alignment_type: str = "sft", + full_or_peft: Optional[str] = "full" ) -> Dict[str, str]: """ Start the alignment of a model. Args: alignment_name (str): The name of the alignment job. - qa_set (str): The name of the QA set to use. + qa_set (Optional[str]): The name of the QA set to use. Required if alignment_type is "sft". pretrained_model (Optional[str]): The name of the pretrained model to use, if any. merging_model (Optional[str]): The name of the merging model to use, if any. alignment_model (Optional[str]): The name of the final alignment model to use, if any. + hf_model (Optional[str]): The name of the Hugging Face model to use, if any. target_compute (Optional[str]): The name of the compute to use, e.g., "g5.2xlarge" or "capacity". If omitted, the default compute will be used. capacity_id (Optional[str]): The name of the capacity block ID to use. If omitted, an instance will be launched to perform training. + alignment_type (str): The type of alignment to perform. Can be "sft" or "dpo". Defaults to "sft". """ + if alignment_type == "sft" and qa_set is None: + raise ValueError("qa_set is required when alignment_type is 'sft'") data = { "alignment_name": alignment_name, @@ -383,6 +388,7 @@ def start_alignment( "hf_model": hf_model, "target_compute": target_compute, "capacity_id": capacity_id, + "alignment_type": alignment_type, } # Assuming make_request is a function that handles the request, it's called here