-
Notifications
You must be signed in to change notification settings - Fork 156
[MoE] Cleanup MoE examples #1576
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
kylesayrs
wants to merge
20
commits into
main
Choose a base branch
from
kylesayrs/cleanup-moe-examples
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+307
−282
Draft
Changes from all commits
Commits
Show all changes
20 commits
Select commit
Hold shift + click to select a range
b30eade
deepseekv3
kylesayrs a957f2f
remove dreg
kylesayrs 2fd2a25
reformat example
kylesayrs b8b217c
wip: clean up moe examples
kylesayrs 43bc91d
remove deepseek2.5 for now
kylesayrs 7d8ed36
update readme
kylesayrs b7273a9
infer model device with optional override
kylesayrs afebe2e
handle nullable dataset_args
kylesayrs ab3aa3e
update docstrings, comments
kylesayrs e9e30c3
rename files, update examples tests
kylesayrs 6bf5acb
rebase on main
kylesayrs e77a31b
clean examples
kylesayrs 366ac25
revert examples changes
kylesayrs c44da34
revert extra examples
kylesayrs 2db2789
revert examples changes
kylesayrs 0dc2381
remove extra examples
kylesayrs b70aba7
revert examples tests changes
kylesayrs 5e5657b
Revert "revert extra examples"
kylesayrs 735c317
Merge branch 'kylesayrs/deepseek-v3' into kylesayrs/cleanup-moe-examples
kylesayrs 4812350
clean up examples
kylesayrs File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
from datasets import load_dataset | ||
from transformers import AutoModelForCausalLM, AutoTokenizer | ||
|
||
from llmcompressor.modeling import prepare_for_quantization | ||
from llmcompressor.modifiers.quantization import GPTQModifier | ||
from llmcompressor.transformers import oneshot | ||
from llmcompressor.utils import dispatch_for_generation | ||
|
||
# Select model and load it. | ||
# For DeepSeekv3, we require a full precision model in order to properly calibrate | ||
# `DeepSeek-V3-BF16` is a DeepSeek-V3 FP8 model which has been converted to BF16 | ||
model_id = "RedHatAI/DeepSeek-V3-BF16" | ||
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto") | ||
tokenizer = AutoTokenizer.from_pretrained(model_id) | ||
model = prepare_for_quantization(model) | ||
|
||
# Select calibration dataset. | ||
DATASET_ID = "HuggingFaceH4/ultrachat_200k" | ||
DATASET_SPLIT = "train_sft" | ||
|
||
# Select number of samples. 512 samples is a good place to start. | ||
# Increasing the number of samples can improve accuracy. | ||
NUM_CALIBRATION_SAMPLES = 512 | ||
MAX_SEQUENCE_LENGTH = 2048 | ||
|
||
# Load dataset and preprocess. | ||
ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]") | ||
ds = ds.shuffle(seed=42) | ||
|
||
|
||
def preprocess(example): | ||
return { | ||
"text": tokenizer.apply_chat_template( | ||
example["messages"], | ||
tokenize=False, | ||
) | ||
} | ||
|
||
|
||
ds = ds.map(preprocess) | ||
|
||
|
||
# Tokenize inputs. | ||
def tokenize(sample): | ||
return tokenizer( | ||
sample["text"], | ||
padding=False, | ||
max_length=MAX_SEQUENCE_LENGTH, | ||
truncation=True, | ||
add_special_tokens=False, | ||
) | ||
|
||
|
||
ds = ds.map(tokenize, remove_columns=ds.column_names) | ||
|
||
# Configure the quantization algorithm to run. | ||
# since the MoE gate layers are sensitive to quantization, we add them to the ignore | ||
# list so they remain at full precision | ||
recipe = GPTQModifier( | ||
targets="Linear", scheme="W4A16", ignore=["lm_head", "re:.*mlp.gate$"] | ||
) | ||
|
||
# Apply algorithms. | ||
# due to the large size of DeepSeekV3, we specify sequential targets such that | ||
# only one MLP is loaded into GPU memory at a time | ||
oneshot( | ||
model=model, | ||
dataset=ds, | ||
recipe=recipe, | ||
max_seq_length=MAX_SEQUENCE_LENGTH, | ||
num_calibration_samples=NUM_CALIBRATION_SAMPLES, | ||
sequential_targets=["DeepseekV3Attention", "DeepseekV3MLP"], | ||
) | ||
|
||
# Confirm generations of the quantized model look sane. | ||
print("\n\n") | ||
print("========== SAMPLE GENERATION ==============") | ||
dispatch_for_generation(model) | ||
sample = tokenizer("Hello my name is", return_tensors="pt") | ||
sample = {key: value.to("cuda") for key, value in sample.items()} | ||
output = model.generate(**sample, max_new_tokens=100) | ||
print(tokenizer.decode(output[0])) | ||
print("==========================================\n\n") | ||
|
||
# Save to disk compressed. | ||
SAVE_DIR = model_id.rstrip("/").split("/")[-1] + "-W4A16-G128" | ||
model.save_pretrained(SAVE_DIR, save_compressed=True) | ||
tokenizer.save_pretrained(SAVE_DIR) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.