Skip to content

Commit

Permalink
bug fixes for ablation
Browse files Browse the repository at this point in the history
  • Loading branch information
benjaminye committed Feb 1, 2024
1 parent 78da909 commit ec3d158
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 23 deletions.
1 change: 1 addition & 0 deletions toolkit/config.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
save_dir: "./experiment/"

ablation:
use_ablate: true
study_name: "ablate_1"

# Data Ingestion -------------------
Expand Down
30 changes: 15 additions & 15 deletions toolkit/src/pydantic_models/config_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ class ModelConfig(BaseModel):
)

quantize: Optional[bool] = Field(False, description="Flag to enable quantization")
bitsandbytes: Optional[BitsAndBytesConfig] = Field(
bitsandbytes: BitsAndBytesConfig = Field(
None, description="Bits and Bytes configuration"
)

Expand Down Expand Up @@ -126,7 +126,7 @@ class LoraConfig(BaseModel):
lora_dropout: Optional[float] = Field(
0.1, description="The dropout probability for Lora layers"
)
target_modules: Optional[Union[List[str], str]] = Field(
target_modules: Optional[List[str]] = Field(
None, description="The names of the modules to apply Lora to"
)
fan_in_fan_out: Optional[bool] = Field(
Expand All @@ -141,12 +141,12 @@ class LoraConfig(BaseModel):
None, description="The layer indexes to transform"
)
layers_pattern: Optional[str] = Field(None, description="The layer pattern name")
rank_pattern: Optional[Dict[str, int]] = Field(
{}, description="The mapping from layer names or regexp expression to ranks"
)
alpha_pattern: Optional[Dict[str, int]] = Field(
{}, description="The mapping from layer names or regexp expression to alphas"
)
# rank_pattern: Optional[Dict[str, int]] = Field(
# {}, description="The mapping from layer names or regexp expression to ranks"
# )
# alpha_pattern: Optional[Dict[str, int]] = Field(
# {}, description="The mapping from layer names or regexp expression to alphas"
# )


# TODO: Get comprehensive Args!
Expand All @@ -161,7 +161,7 @@ class TrainingArgs(BaseModel):
gradient_checkpointing: Optional[bool] = Field(
True, description="Flag to enable gradient checkpointing"
)
# optim: Optional[str] = Field("paged_adamw_32bit", description="Optimizer")
optim: Optional[str] = Field("paged_adamw_32bit", description="Optimizer")
logging_steps: Optional[int] = Field(100, description="Number of logging steps")
learning_rate: Optional[float] = Field(2.0e-4, description="Learning rate")
bf16: Optional[bool] = Field(False, description="Flag to enable bf16")
Expand All @@ -184,8 +184,8 @@ class SftArgs(BaseModel):


class TrainingConfig(BaseModel):
training_args: Optional[TrainingArgs]
sft_args: Optional[SftArgs]
training_args: TrainingArgs
sft_args: SftArgs


# TODO: Get comprehensive Args!
Expand All @@ -207,13 +207,13 @@ class AblationConfig(BaseModel):

class Config(BaseModel):
save_dir: Optional[str] = Field("./experiments", description="Folder to save to")
ablation: Optional[AblationConfig]
ablation: AblationConfig
accelerate: Optional[bool] = Field(
False,
description="set to True if you want to use multi-gpu training; then launch with `accelerate launch --config_file ./accelerate_config toolkit.py`",
)
data: DataConfig
model: ModelConfig
lora: Optional[LoraConfig]
training: Optional[TrainingConfig]
inference: Optional[InferenceConfig]
lora: LoraConfig
training: TrainingConfig
inference: InferenceConfig
2 changes: 1 addition & 1 deletion toolkit/src/utils/save_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,5 +76,5 @@ def save_config(self) -> None:
os.makedirs(self.save_paths.config, exist_ok=True)
model_dict = self.config.model_dump()

with open(os.path.join(self.save_paths.config, "config.yaml"), "w") as file:
with open(os.path.join(self.save_paths.config, "config.yml"), "w") as file:
yaml.dump(model_dict, file)
13 changes: 6 additions & 7 deletions toolkit/toolkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,11 @@ def run_one_experiment(config: Config) -> None:
# Load YAML config
with open(config_path, "r") as file:
config = yaml.safe_load(file)
if config.get("ablation"):
if config["ablation"].get("use_ablate", False):
configs = generate_permutations(config, Config)
else:
configs = [config]

configs = (
generate_permutations(config, Config)
if config.get("ablation", {}).get("use_ablate", False)
else [config]
)
for config in configs:
try:
config = Config(**config)
Expand All @@ -99,7 +98,7 @@ def run_one_experiment(config: Config) -> None:
dir_helper = DirectoryHelper(config_path, config)

# Reload config from saved config
with open(join(dir_helper.config_path.config, "config.yml"), "r") as file:
with open(join(dir_helper.save_paths.config, "config.yml"), "r") as file:
config = yaml.safe_load(file)
config = Config(**config)

Expand Down

0 comments on commit ec3d158

Please sign in to comment.