Commit message:
* add llama
* reformat & update requirements
* refine
* refine
* reformat
* add pipeline
* add alpaca and test
* add sft
* update rotary embedding
* refine
* support llama trainer
* modify rotary embed
* test sft ckpt
* support EleutherAI
* refine
* refine and add readme
* fix inference in float16
* fix mask dtype
* add activation_checkpoint
* refine by comment
* refine
Commit ddb5ea1 (1 parent: b97fc9f), showing 16 changed files with 1,692 additions and 18 deletions.
New file (62 lines added): the Llama model config, imported later in this commit as projects.Llama.configs.llama_config.
from omegaconf import DictConfig, OmegaConf

from libai.config import LazyCall
from projects.Llama.llama import LlamaForCausalLM
from projects.Llama.tokenizer import LlamaTokenizer
from configs.common.train import train


cfg = dict(
    # Model
    hidden_act="silu",
    hidden_size=4096,
    initializer_range=0.02,
    intermediate_size=11008,
    max_position_embeddings=4096,
    num_attention_heads=32,
    hidden_layers=32,
    num_key_value_heads=32,
    pretraining_tp=1,
    rms_norm_eps=1e-05,
    rope_scaling=None,
    tie_word_embeddings=False,
    vocab_size=32000,
    use_scaled_init_for_output_weights=False,
    scale_mask_softmax_fusion=False,
    amp_enabled=True,
    # Inference
    is_encoder_decoder=False,
    max_length=256,
    min_length=0,
    do_sample=False,
    early_stopping=False,
    num_beams=1,
    num_beam_groups=1,
    diversity_penalty=0.0,
    temperature=0.9,
    top_k=50,
    top_p=0.6,
    typical_p=1.0,
    repetition_penalty=1.0,
    length_penalty=1.0,
    no_repeat_ngram_size=0,
    encoder_no_repeat_ngram_size=0,
    num_return_sequences=1,
    chunk_size_feed_forward=0,
    output_scores=False,
    use_cache=True,
    bos_token_id=1,
    eos_token_id=2,
    pad_token_id=0,
    # Train
    pretrained_model_path="meta-llama/Llama-2-7b-hf",
)

cfg = DictConfig(cfg)

model = LazyCall(LlamaForCausalLM)(cfg=cfg)
tokenization = OmegaConf.create()
tokenization.make_vocab_size_divisible_by = 1
tokenization.tokenizer = LazyCall(LlamaTokenizer)(
    pretrained_model_path="Llama-2-7b-hf/tokenizer.model"
)
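The file above only describes objects lazily; nothing is constructed until a training or inference entry point materializes the LazyCall nodes. A minimal sketch of that step, assuming LibAI exposes an instantiate helper for its lazy-config system (detectron2-style LazyConfig); the instantiate import is an assumption, not part of this commit:

from libai.config import instantiate  # assumed helper of LibAI's lazy-config system
from projects.Llama.configs.llama_config import cfg, model, tokenization

llama = instantiate(model)                       # builds LlamaForCausalLM(cfg=cfg)
tokenizer = instantiate(tokenization.tokenizer)  # builds LlamaTokenizer from tokenizer.model
print(cfg.hidden_size, cfg.num_attention_heads)  # 4096 32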
New file (102 lines added): the supervised fine-tuning (SFT) config that wires the Llama model, the Alpaca data, and the LibAI trainer together.
import os
from omegaconf import OmegaConf

from libai.config import LazyCall
from libai.evaluation import PPLEvaluator
from libai.scheduler import WarmupExponentialLR
from libai.data.build import build_nlp_test_loader, build_nlp_train_loader

from configs.common.train import train
from configs.common.models.graph import graph
from configs.common.optim import optim

from projects.Llama.configs.llama_config import cfg
from projects.Llama.dataset import AlpacaDataset
from projects.Llama.tokenizer import LlamaTokenizer
from projects.Llama.llama import LlamaForCausalLM


# Hyperparameters
weight_decay = 0.1
learning_rate = 2e-5
max_input_length = 1350
dataset_path = "alpaca_data"
pretrained_model_path = "meta-llama/Llama-2-7b-hf"

# graph & optim
graph["enabled"] = True
optim.update(
    dict(
        lr=learning_rate,
        weight_decay=weight_decay,
    )
)

# tokenize
tokenization = OmegaConf.create()
tokenization.make_vocab_size_divisible_by = 1
tokenization.tokenizer = LazyCall(LlamaTokenizer)(
    pretrained_model_path=os.path.join(pretrained_model_path, "tokenizer.model")
)

# model
model = LazyCall(LlamaForCausalLM)(cfg=cfg)

# datasets
dataloader = OmegaConf.create()
dataloader.train = LazyCall(build_nlp_train_loader)(
    dataset=[
        LazyCall(AlpacaDataset)(
            path=os.path.join(dataset_path, "train"),
            tokenizer=tokenization.tokenizer,
            max_len=max_input_length,
        )
    ],
)
dataloader.test = [
    LazyCall(build_nlp_test_loader)(
        dataset=LazyCall(AlpacaDataset)(
            path=os.path.join(dataset_path, "test"),
            tokenizer=tokenization.tokenizer,
            max_len=max_input_length,
        ),
    ),
]


train.update(
    dict(
        output_dir="./sft_result",
        train_micro_batch_size=2,
        test_micro_batch_size=1,
        train_epoch=5,
        train_iter=1,
        log_period=10,
        warmup_ratio=2 / 5,
        num_accumulation_steps=8,
        rdma_enabled=True,
        amp=dict(enabled=True),
        activation_checkpoint=dict(enabled=True),
        checkpointer=dict(
            period=100,
            max_to_keep=20,
        ),
        dist=dict(
            data_parallel_size=2,
            tensor_parallel_size=1,
            pipeline_parallel_size=4,
            pipeline_num_layers=cfg.hidden_layers,
        ),
        evaluation=dict(
            enabled=True,
            evaluator=LazyCall(PPLEvaluator)(),
            eval_period=100,
            eval_iter=1e5,
        ),
        scheduler=LazyCall(WarmupExponentialLR)(
            warmup_factor=0.0,
            gamma=1.0,
            warmup_method="linear",
        ),
    )
)
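For orientation, the dist block above implies 2 x 1 x 4 = 8 processes, i.e. an 8-GPU run (configs of this shape are typically launched through LibAI's tools/train.sh entry point, though the launch command is not part of this diff). A plain-Python sanity check of those numbers, assuming the usual micro-batch accounting where the global batch is micro_batch x accumulation x data_parallel:

# Plain arithmetic over the values set in train.update() above; no LibAI import needed.
data_parallel, tensor_parallel, pipeline_parallel = 2, 1, 4
micro_batch, accumulation = 2, 8

world_size = data_parallel * tensor_parallel * pipeline_parallel  # 8 devices in total
global_batch = micro_batch * accumulation * data_parallel         # 32 samples per optimizer step
print(world_size, global_batch)                                   # 8 32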
New file (46 lines added): the Alpaca dataset wrapper, imported above as projects.Llama.dataset.
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random

import oneflow as flow
from oneflow.utils.data import Dataset

from libai.data.structures import DistTensorData, Instance


def pad_right(data, pad_id=0, max_len=1350):
    # Append `pad_id` on the right until the 1-D tensor reaches `max_len`.
    n = max_len - data.shape[0]
    return flow.cat((data, flow.full((n,), pad_id, dtype=data.dtype)))


class AlpacaDataset(Dataset):
    def __init__(self, path, tokenizer, max_len=1350):
        # `path` is read with flow.load and must yield an indexable collection of
        # dicts holding pre-tokenized "input_ids" and "labels" tensors.
        self.data = flow.load(path)
        random.shuffle(self.data)
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        # input_ids are padded with 0 (the config's pad_token_id); labels are padded
        # with -1, which is typically treated as an ignore index by the loss.
        input_ids = pad_right(self.data[index]["input_ids"], pad_id=0, max_len=self.max_len)
        labels = pad_right(self.data[index]["labels"], pad_id=-1, max_len=self.max_len)

        return Instance(
            input_ids=DistTensorData(input_ids),
            labels=DistTensorData(labels),
        )
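To make the padding concrete, here is a tiny illustration of pad_right with made-up token ids (the values are hypothetical; only the function defined above is assumed):

import oneflow as flow

ids = flow.tensor([1, 306, 626], dtype=flow.int64)  # a hypothetical 3-token sample
padded = pad_right(ids, pad_id=0, max_len=8)
print(padded)  # [1, 306, 626, 0, 0, 0, 0, 0]

Note that pad_right assumes each stored sample is already at most max_len tokens long; a longer sample would make n negative, so any truncation to max_input_length has to happen when the alpaca_data dump is tokenized and saved.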