Skip to content

Commit 9056c46

Browse files
authored
Enable quantizing local checkpoints in model release script (#2859)
Enable quantizing local checkpoints in model release script Summary: For torchao model release scripts, previously we only support quantizing models downloaded from hf directly (with a model id), this PR turns it off by default and allows users to quantize a local checkpoint Test Plan: cd .github/scripts/torchao_model_releases/ ./release.sh --model_id $LOCAL_MODEL_PATH --quants FP8 Reviewers: Subscribers: Tasks: Tags:
1 parent d321a2c commit 9056c46

File tree

3 files changed

+24
-9
lines changed

3 files changed

+24
-9
lines changed

.github/scripts/torchao_model_releases/quantize_and_upload.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -568,7 +568,7 @@ def _untie_weights_and_save_locally(model_id):
568568
"""
569569

570570

571-
def quantize_and_upload(model_id, quant):
571+
def quantize_and_upload(model_id, quant, push_to_hub):
572572
_int8_int4_linear_config = Int8DynamicActivationIntxWeightConfig(
573573
weight_dtype=torch.int4,
574574
weight_granularity=PerGroup(32),
@@ -657,9 +657,13 @@ def quantize_and_upload(model_id, quant):
657657
card = ModelCard(content)
658658

659659
# Push to hub
660-
quantized_model.push_to_hub(quantized_model_id, safe_serialization=False)
661-
tokenizer.push_to_hub(quantized_model_id)
662-
card.push_to_hub(quantized_model_id)
660+
if push_to_hub:
661+
quantized_model.push_to_hub(quantized_model_id, safe_serialization=False)
662+
tokenizer.push_to_hub(quantized_model_id)
663+
card.push_to_hub(quantized_model_id)
664+
else:
665+
quantized_model.save_pretrained(quantized_model_id, safe_serialization=False)
666+
tokenizer.save_pretrained(quantized_model_id)
663667

664668
# Manual Testing
665669
prompt = "Hey, are you conscious? Can you talk to me?"
@@ -700,5 +704,11 @@ def quantize_and_upload(model_id, quant):
700704
type=str,
701705
help="Quantization method. Options are FP8, INT4, INT8_INT4, AWQ-INT4",
702706
)
707+
parser.add_argument(
708+
"--push_to_hub",
709+
action="store_true",
710+
default=False,
711+
help="Flag to indicate whether push to huggingface hub or not",
712+
)
703713
args = parser.parse_args()
704-
quantize_and_upload(args.model_id, args.quant)
714+
quantize_and_upload(args.model_id, args.quant, args.push_to_hub)

.github/scripts/torchao_model_releases/release.sh

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
# Default quantization options
1616
default_quants=("FP8" "INT4" "INT8-INT4")
17+
push_to_hub=""
1718
# Parse arguments
1819
while [[ $# -gt 0 ]]; do
1920
case "$1" in
@@ -29,6 +30,10 @@ while [[ $# -gt 0 ]]; do
2930
shift
3031
done
3132
;;
33+
--push_to_hub)
34+
push_to_hub="--push_to_hub"
35+
shift
36+
;;
3237
*)
3338
echo "Unknown option: $1"
3439
exit 1
@@ -38,14 +43,14 @@ done
3843
# Use default quants if none specified
3944
if [[ -z "$model_id" ]]; then
4045
echo "Error: --model_id is required"
41-
echo "Usage: $0 --model_id <model_id> [--quants <quant1> [quant2 ...]]"
46+
echo "Usage: $0 --model_id <model_id> [--quants <quant1> [quant2 ...]] [--push_to_hub]"
4247
exit 1
4348
fi
4449
if [[ ${#quants[@]} -eq 0 ]]; then
4550
quants=("${default_quants[@]}")
4651
fi
4752
# Run the python command for each quantization option
4853
for quant in "${quants[@]}"; do
49-
echo "Running: python quantize_and_upload.py --model_id $model_id --quant $quant"
50-
python quantize_and_upload.py --model_id "$model_id" --quant "$quant"
54+
echo "Running: python quantize_and_upload.py --model_id $model_id --quant $quant $push_to_hub"
55+
python quantize_and_upload.py --model_id "$model_id" --quant "$quant" $push_to_hub
5156
done

.github/workflows/release_model.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,4 +43,4 @@ jobs:
4343
pip install .
4444
HF_MODEL_ID=${{ github.event.inputs.hf_model_id }}
4545
cd .github/scripts/torchao_model_releases
46-
./release.sh --model_id $HF_MODEL_ID
46+
./release.sh --model_id $HF_MODEL_ID --push_to_hub

0 commit comments

Comments
 (0)