Skip to content

Commit 3aefdad

Browse files
committed
release v0.9.0 (real)
Former-commit-id: 90d6df6
1 parent 561ae4d commit 3aefdad

File tree

7 files changed

+45
-53
lines changed

7 files changed

+45
-53
lines changed

.env.local

+3-1
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,12 @@ WANDB_DISABLED=
2424
WANDB_PROJECT=huggingface
2525
WANDB_API_KEY=
2626
# gradio ui
27-
GRADIO_SHARE=0
27+
GRADIO_SHARE=False
2828
GRADIO_SERVER_NAME=0.0.0.0
2929
GRADIO_SERVER_PORT=
3030
GRADIO_ROOT_PATH=
31+
# setup
32+
ENABLE_SHORT_CONSOLE=1
3133
# reserved (do not use)
3234
LLAMABOARD_ENABLED=
3335
LLAMABOARD_WORKDIR=

README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,7 @@ You also can add a custom chat template to [template.py](src/llamafactory/data/t
275275
- [Magpie-Pro-300K-Filtered (en)](https://huggingface.co/datasets/Magpie-Align/Magpie-Pro-300K-Filtered)
276276
- [Magpie-ultra-v0.1 (en)](https://huggingface.co/datasets/argilla/magpie-ultra-v0.1)
277277
- [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k)
278-
- [Pokemon-gpt4o-captions](https://huggingface.co/datasets/jugg1024/pokemon-gpt4o-captions)
278+
- [Pokemon-gpt4o-captions (en&zh)](https://huggingface.co/datasets/jugg1024/pokemon-gpt4o-captions)
279279
- [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
280280
- [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de)
281281
- [Alpaca GPT4 (de)](https://huggingface.co/datasets/mayflowergmbh/alpaca-gpt4_de)

README_zh.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,7 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
276276
- [Magpie-Pro-300K-Filtered (en)](https://huggingface.co/datasets/Magpie-Align/Magpie-Pro-300K-Filtered)
277277
- [Magpie-ultra-v0.1 (en)](https://huggingface.co/datasets/argilla/magpie-ultra-v0.1)
278278
- [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k)
279-
- [Pokemon-gpt4o-captions](https://huggingface.co/datasets/jugg1024/pokemon-gpt4o-captions)
279+
- [Pokemon-gpt4o-captions (en&zh)](https://huggingface.co/datasets/jugg1024/pokemon-gpt4o-captions)
280280
- [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
281281
- [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de)
282282
- [Alpaca GPT4 (de)](https://huggingface.co/datasets/mayflowergmbh/alpaca-gpt4_de)

scripts/cal_mfu.py

+3
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,9 @@ def calculate_mfu(
131131
"dataset": "c4_demo",
132132
"cutoff_len": seq_length,
133133
"output_dir": os.path.join("saves", "test_mfu"),
134+
"logging_strategy": "no",
135+
"save_strategy": "no",
136+
"save_only_model": True,
134137
"overwrite_output_dir": True,
135138
"per_device_train_batch_size": batch_size,
136139
"max_steps": num_steps,

setup.py

+12-3
Original file line numberDiff line numberDiff line change
@@ -14,25 +14,34 @@
1414

1515
import os
1616
import re
17+
from typing import List
1718

1819
from setuptools import find_packages, setup
1920

2021

21-
def get_version():
22+
def get_version() -> str:
2223
with open(os.path.join("src", "llamafactory", "extras", "env.py"), "r", encoding="utf-8") as f:
2324
file_content = f.read()
2425
pattern = r"{}\W*=\W*\"([^\"]+)\"".format("VERSION")
2526
(version,) = re.findall(pattern, file_content)
2627
return version
2728

2829

29-
def get_requires():
30+
def get_requires() -> List[str]:
3031
with open("requirements.txt", "r", encoding="utf-8") as f:
3132
file_content = f.read()
3233
lines = [line.strip() for line in file_content.strip().split("\n") if not line.startswith("#")]
3334
return lines
3435

3536

37+
def get_console_scripts() -> List[str]:
38+
console_scripts = ["llamafactory-cli = llamafactory.cli:main"]
39+
if os.environ.get("ENABLE_SHORT_CONSOLE", "1").lower() in ["true", "1"]:
40+
console_scripts.append("lmf = llamafactory.cli:main")
41+
42+
return console_scripts
43+
44+
3645
extra_require = {
3746
"torch": ["torch>=1.13.1"],
3847
"torch-npu": ["torch==2.1.0", "torch-npu==2.1.0.post3", "decorator"],
@@ -72,7 +81,7 @@ def main():
7281
python_requires=">=3.8.0",
7382
install_requires=get_requires(),
7483
extras_require=extra_require,
75-
entry_points={"console_scripts": ["llamafactory-cli = llamafactory.cli:main"]},
84+
entry_points={"console_scripts": get_console_scripts()},
7685
classifiers=[
7786
"Development Status :: 4 - Beta",
7887
"Intended Audience :: Developers",

src/llamafactory/extras/constants.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -829,7 +829,7 @@ def register_model_group(
829829

830830
register_model_group(
831831
models={
832-
"MiniCPM3-4B": {
832+
"MiniCPM3-4B-Chat": {
833833
DownloadSource.DEFAULT: "openbmb/MiniCPM3-4B",
834834
DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM3-4B",
835835
},

src/llamafactory/train/callbacks.py

+24-46
Original file line numberDiff line numberDiff line change
@@ -96,38 +96,45 @@ def fix_valuehead_checkpoint(
9696

9797

9898
class FixValueHeadModelCallback(TrainerCallback):
99+
r"""
100+
A callback for fixing the checkpoint for valuehead models.
101+
"""
102+
99103
@override
100104
def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
101105
r"""
102106
Event called after a checkpoint save.
103107
"""
104108
if args.should_save:
109+
output_dir = os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
105110
fix_valuehead_checkpoint(
106-
model=kwargs.pop("model"),
107-
output_dir=os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step)),
108-
safe_serialization=args.save_safetensors,
111+
model=kwargs.pop("model"), output_dir=output_dir, safe_serialization=args.save_safetensors
109112
)
110113

111114

112115
class SaveProcessorCallback(TrainerCallback):
116+
r"""
117+
A callback for saving the processor.
118+
"""
119+
113120
def __init__(self, processor: "ProcessorMixin") -> None:
114-
r"""
115-
Initializes a callback for saving the processor.
116-
"""
117121
self.processor = processor
118122

123+
@override
124+
def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
125+
if args.should_save:
126+
output_dir = os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
127+
getattr(self.processor, "image_processor").save_pretrained(output_dir)
128+
119129
@override
120130
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
121-
r"""
122-
Event called at the end of training.
123-
"""
124131
if args.should_save:
125132
getattr(self.processor, "image_processor").save_pretrained(args.output_dir)
126133

127134

128135
class PissaConvertCallback(TrainerCallback):
129136
r"""
130-
Initializes a callback for converting the PiSSA adapter to a normal one.
137+
A callback for converting the PiSSA adapter to a normal one.
131138
"""
132139

133140
@override
@@ -147,9 +154,6 @@ def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", contr
147154

148155
@override
149156
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
150-
r"""
151-
Event called at the end of training.
152-
"""
153157
if args.should_save:
154158
model = kwargs.pop("model")
155159
pissa_init_dir = os.path.join(args.output_dir, "pissa_init")
@@ -177,21 +181,22 @@ def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control
177181

178182

179183
class LogCallback(TrainerCallback):
184+
r"""
185+
A callback for logging training and evaluation status.
186+
"""
187+
180188
def __init__(self) -> None:
181-
r"""
182-
Initializes a callback for logging training and evaluation status.
183-
"""
184-
""" Progress """
189+
# Progress
185190
self.start_time = 0
186191
self.cur_steps = 0
187192
self.max_steps = 0
188193
self.elapsed_time = ""
189194
self.remaining_time = ""
190195
self.thread_pool: Optional["ThreadPoolExecutor"] = None
191-
""" Status """
196+
# Status
192197
self.aborted = False
193198
self.do_train = False
194-
""" Web UI """
199+
# Web UI
195200
self.webui_mode = os.environ.get("LLAMABOARD_ENABLED", "0").lower() in ["true", "1"]
196201
if self.webui_mode:
197202
signal.signal(signal.SIGABRT, self._set_abort)
@@ -233,9 +238,6 @@ def _close_thread_pool(self) -> None:
233238

234239
@override
235240
def on_init_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
236-
r"""
237-
Event called at the end of the initialization of the `Trainer`.
238-
"""
239241
if (
240242
args.should_save
241243
and os.path.exists(os.path.join(args.output_dir, TRAINER_LOG))
@@ -246,60 +248,39 @@ def on_init_end(self, args: "TrainingArguments", state: "TrainerState", control:
246248

247249
@override
248250
def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
249-
r"""
250-
Event called at the beginning of training.
251-
"""
252251
if args.should_save:
253252
self.do_train = True
254253
self._reset(max_steps=state.max_steps)
255254
self._create_thread_pool(output_dir=args.output_dir)
256255

257256
@override
258257
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
259-
r"""
260-
Event called at the end of training.
261-
"""
262258
self._close_thread_pool()
263259

264260
@override
265261
def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
266-
r"""
267-
Event called at the end of an substep during gradient accumulation.
268-
"""
269262
if self.aborted:
270263
control.should_epoch_stop = True
271264
control.should_training_stop = True
272265

273266
@override
274267
def on_step_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
275-
r"""
276-
Event called at the end of a training step.
277-
"""
278268
if self.aborted:
279269
control.should_epoch_stop = True
280270
control.should_training_stop = True
281271

282272
@override
283273
def on_evaluate(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
284-
r"""
285-
Event called after an evaluation phase.
286-
"""
287274
if not self.do_train:
288275
self._close_thread_pool()
289276

290277
@override
291278
def on_predict(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
292-
r"""
293-
Event called after a successful prediction.
294-
"""
295279
if not self.do_train:
296280
self._close_thread_pool()
297281

298282
@override
299283
def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
300-
r"""
301-
Event called after logging the last logs.
302-
"""
303284
if not args.should_save:
304285
return
305286

@@ -342,9 +323,6 @@ def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "Tra
342323
def on_prediction_step(
343324
self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
344325
):
345-
r"""
346-
Event called after a prediction step.
347-
"""
348326
if self.do_train:
349327
return
350328

0 commit comments

Comments
 (0)