Commit
Llama2 (#524)
* add llama

* reformat & update requirements

* refine

* refine

* reformat

* add pipeline

* add alpaca and test

* add sft

* update rotary embedding

* refine

* support llama trainer

* modify rotary embed

* test sft ckpt

* support EleutherAI

* refine

* refine and add readme

* fix inference in float16

* fix mask dtype

* add activation_checkpoint

* refine by comment

* refine
xiezipeng-ML authored Dec 18, 2023
1 parent b97fc9f commit ddb5ea1
Showing 16 changed files with 1,692 additions and 18 deletions.
11 changes: 8 additions & 3 deletions libai/inference/generator/generation_utils.py
@@ -124,7 +124,10 @@ def _prepare_attention_mask_for_generation(
pad_token_id: Optional[int],
eos_token_id: Optional[int],
):
-        is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [flow.int64, flow.long]
+        is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [
+            flow.int64,
+            flow.long,
+        ]
is_pad_token_in_inputs = (pad_token_id is not None) and (pad_token_id in inputs)
is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (
(eos_token_id is not None) and (pad_token_id != eos_token_id)
@@ -502,7 +505,7 @@ def greedy_search(
next_tokens = next_tokens.to_global(placement=input_ids.placement)
unfinished_sequences = unfinished_sequences.to_global(
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
-                placement=dist.get_layer_placement(0),
+                placement=input_ids.placement,
)

if eos_token_id is not None:
@@ -987,7 +990,9 @@ def generate(

# 8. Prepare stopping criteria
stopping_criteria = self._get_stopping_criteria(
-            max_length=max_length, max_time=max_time, stopping_criteria=stopping_criteria
+            max_length=max_length,
+            max_time=max_time,
+            stopping_criteria=stopping_criteria,
)

# 9. Go into different generation modes
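Note: `_prepare_attention_mask_for_generation` only decides whether an attention mask can be derived at all — the inputs must be 2-D integer token ids, a pad token must be present, and it must differ from the EOS token. A rough sketch of the kind of mask such a check ultimately enables (illustrative only; `make_attention_mask` is not a LibAI function and the token ids are made up):

import oneflow as flow

def make_attention_mask(input_ids, pad_token_id):
    # 1 for real tokens, 0 for padding; cast so it can later be moved to
    # whatever dtype/placement the model expects.
    return (input_ids != pad_token_id).to(flow.int64)

ids = flow.tensor([[5, 17, 42, 0, 0]])           # 0 standing in for the pad id
mask = make_attention_mask(ids, pad_token_id=0)  # -> [[1, 1, 1, 0, 0]]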
37 changes: 24 additions & 13 deletions libai/models/utils/model_loader/base_loader.py
@@ -383,7 +383,6 @@ def _convert_tensor(self, tensor):
Returns:
flow.Tensor: The target tensor.
"""
-        tensor = tensor.float()
return flow.Tensor(tensor.detach().cpu().numpy())

def _convert_tensors(self, torch_state_dict):
@@ -465,8 +464,15 @@ def _load_torch_state_dict(self, state_dict_file):
raise ImportError("Load torch state dict need torch.")

# load pytorch_model.bin
-        state_dict = torch.load(state_dict_file, map_location="cpu")
-        return state_dict
+        if isinstance(state_dict_file, str):
+            return torch.load(state_dict_file, map_location="cpu")
+
+        if isinstance(state_dict_file, list):
+            merged_state_dict = {}
+            for file in state_dict_file:
+                state_dict = torch.load(file, map_location="cpu")
+                merged_state_dict.update(state_dict)
+            return merged_state_dict

def _update_cfg(self, keys_libai, value_target):
"""Update the libai_cfg according to target_cfg.
@@ -491,11 +497,12 @@ def _update_cfg_log(self):
f"changed libai model cfg {temp_key} : "
f"{self.origin_libai_cfg[key]} -> {self.libai_cfg[key]} "
)
-        logger.warning(
-            "The following model configurations has been modified according "
-            "to `config.json` or kwargs: \n"
-            f"{self.changed_keys} \n"
-        )
+        if len(self.changed_keys) > 0:
+            logger.warning(
+                "The following model configurations has been modified according "
+                "to `config.json` or kwargs: \n"
+                f"{self.changed_keys} \n"
+            )

if dist.get_pipeline_parallel_size() > 1:
logger.warning(
@@ -528,11 +535,15 @@ def load(self):
if dist.is_main_process():
if os.path.isdir(self.pretrained_model_path):
# state_dict file pytorch
-                if os.path.isfile(os.path.join(self.pretrained_model_path, WEIGHTS_NAME_PT)):
-                    model_file = os.path.join(self.pretrained_model_path, WEIGHTS_NAME_PT)
-                else:
+                model_files = [
+                    os.path.join(self.pretrained_model_path, file)
+                    for file in os.listdir(self.pretrained_model_path)
+                    if file.endswith(".bin")
+                ]
+
+                if len(model_files) == 0:
                    raise EnvironmentError(
-                        f"Error no file named {WEIGHTS_NAME_PT} found"
+                        f"Error: no file named endswith '.bin' found"
f"in directory {self.pretrained_model_path}."
)

@@ -554,7 +565,7 @@ def load(self):
raise EnvironmentError(f"{self.pretrained_model_path} is not a directory.")

logger.info("loading torch model...")
-            torch_state_dict = self._load_torch_state_dict(model_file)
+            torch_state_dict = self._load_torch_state_dict(model_files)
torch_state_dict = self._fix_key(torch_state_dict)
logger.info("transfering torch model into oneflow model...")
flow_state_dict = self._convert_tensors(torch_state_dict)
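Note: with this change `_load_torch_state_dict` accepts either a single weight file or a list of sharded `.bin` files, which is how larger HuggingFace checkpoints such as Llama-2-7b ship. A standalone sketch of the same merging idea, assuming a HuggingFace-style directory of `pytorch_model-0000x-of-0000y.bin` shards; the helper name and directory path are illustrative:

import os
import torch

def load_sharded_state_dict(model_dir):
    shard_files = sorted(
        os.path.join(model_dir, f)
        for f in os.listdir(model_dir)
        if f.endswith(".bin")
    )
    merged = {}
    for shard in shard_files:
        # Each shard holds a disjoint subset of parameter names, so a plain
        # dict update reassembles the full state dict.
        merged.update(torch.load(shard, map_location="cpu"))
    return merged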
2 changes: 1 addition & 1 deletion libai/tokenizer/build.py
@@ -24,7 +24,7 @@ def build_tokenizer(cfg):
"""Initialize tokenizer."""
tokenizer = instantiate(cfg.tokenizer)

-    if cfg.append_eod and tokenizer.eod_token is None:
+    if cfg.get("append_eod", None) and tokenizer.eod_token is None:
if tokenizer.eos_token is not None:
tokenizer.eod_token = tokenizer.eos_token
else:
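Note: the `cfg.get(...)` form lets configs that never define `append_eod` (such as the new Llama configs below) pass through `build_tokenizer`, falling back to the default instead of failing on a missing key. A tiny illustration with an ad-hoc OmegaConf config; the values are made up:

from omegaconf import OmegaConf

cfg = OmegaConf.create({"make_vocab_size_divisible_by": 1})  # no "append_eod" key
print(cfg.get("append_eod", None))  # -> None, so the eod-token setup is simply skipped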
62 changes: 62 additions & 0 deletions projects/Llama/configs/llama_config.py
@@ -0,0 +1,62 @@
from omegaconf import DictConfig, OmegaConf

from libai.config import LazyCall
from projects.Llama.llama import LlamaForCausalLM
from projects.Llama.tokenizer import LlamaTokenizer
from configs.common.train import train


cfg = dict(
# Model
hidden_act="silu",
hidden_size=4096,
initializer_range=0.02,
intermediate_size=11008,
max_position_embeddings=4096,
num_attention_heads=32,
hidden_layers=32,
num_key_value_heads=32,
pretraining_tp=1,
rms_norm_eps=1e-05,
rope_scaling=None,
tie_word_embeddings=False,
vocab_size=32000,
use_scaled_init_for_output_weights=False,
scale_mask_softmax_fusion=False,
amp_enabled=True,
# Inference
is_encoder_decoder=False,
max_length=256,
min_length=0,
do_sample=False,
early_stopping=False,
num_beams=1,
num_beam_groups=1,
diversity_penalty=0.0,
temperature=0.9,
top_k=50,
top_p=0.6,
typical_p=1.0,
repetition_penalty=1.0,
length_penalty=1.0,
no_repeat_ngram_size=0,
encoder_no_repeat_ngram_size=0,
num_return_sequences=1,
chunk_size_feed_forward=0,
output_scores=False,
use_cache=True,
bos_token_id=1,
eos_token_id=2,
pad_token_id=0,
# train
pretrained_model_path="meta-llama/Llama-2-7b-hf",
)

cfg = DictConfig(cfg)

model = LazyCall(LlamaForCausalLM)(cfg=cfg)
tokenization = OmegaConf.create()
tokenization.make_vocab_size_divisible_by = 1
tokenization.tokenizer = LazyCall(LlamaTokenizer)(
pretrained_model_path="Llama-2-7b-hf/tokenizer.model"
)
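Note: `model` and `tokenization.tokenizer` above are lazy blueprints, not live objects. A minimal sketch of materializing them, assuming LibAI exposes `instantiate` alongside `LazyCall` in `libai.config` (as detectron2-style lazy configs do) and that the placeholder paths point at a real checkpoint:

from libai.config import instantiate  # assumed to be exported next to LazyCall

tokenizer = instantiate(tokenization.tokenizer)  # builds LlamaTokenizer from tokenizer.model
llama = instantiate(model)                       # builds LlamaForCausalLM with the `cfg` above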
102 changes: 102 additions & 0 deletions projects/Llama/configs/llama_sft.py
@@ -0,0 +1,102 @@
import os
from omegaconf import OmegaConf

from libai.config import LazyCall
from libai.evaluation import PPLEvaluator
from libai.scheduler import WarmupExponentialLR
from libai.data.build import build_nlp_test_loader, build_nlp_train_loader

from configs.common.train import train
from configs.common.models.graph import graph
from configs.common.optim import optim

from projects.Llama.configs.llama_config import cfg
from projects.Llama.dataset import AlpacaDataset
from projects.Llama.tokenizer import LlamaTokenizer
from projects.Llama.llama import LlamaForCausalLM


# Hyperparameters
weight_decay = 0.1
learning_rate = 2e-5
max_input_length = 1350
dataset_path = "alpaca_data"
pretrained_model_path = "meta-llama/Llama-2-7b-hf"

# graph & optim
graph["enabled"] = True
optim.update(
dict(
lr=learning_rate,
weight_decay=weight_decay,
)
)

# tokenize
tokenization = OmegaConf.create()
tokenization.make_vocab_size_divisible_by = 1
tokenization.tokenizer = LazyCall(LlamaTokenizer)(
pretrained_model_path=os.path.join(pretrained_model_path, "tokenizer.model")
)

# model
model = LazyCall(LlamaForCausalLM)(cfg=cfg)

# datasets
dataloader = OmegaConf.create()
dataloader.train = LazyCall(build_nlp_train_loader)(
dataset=[
LazyCall(AlpacaDataset)(
path=os.path.join(dataset_path, "train"),
tokenizer=tokenization.tokenizer,
max_len=max_input_length,
)
],
)
dataloader.test = [
LazyCall(build_nlp_test_loader)(
dataset=LazyCall(AlpacaDataset)(
path=os.path.join(dataset_path, "test"),
tokenizer=tokenization.tokenizer,
max_len=max_input_length,
),
),
]


train.update(
dict(
output_dir="./sft_result",
train_micro_batch_size=2,
test_micro_batch_size=1,
train_epoch=5,
train_iter=1,
log_period=10,
warmup_ratio=2 / 5,
num_accumulation_steps=8,
rdma_enabled=True,
amp=dict(enabled=True),
activation_checkpoint=dict(enabled=True),
checkpointer=dict(
period=100,
max_to_keep=20,
),
dist=dict(
data_parallel_size=2,
tensor_parallel_size=1,
pipeline_parallel_size=4,
pipeline_num_layers=cfg.hidden_layers,
),
evaluation=dict(
enabled=True,
evaluator=LazyCall(PPLEvaluator)(),
eval_period=100,
eval_iter=1e5,
),
scheduler=LazyCall(WarmupExponentialLR)(
warmup_factor=0.0,
gamma=1.0,
warmup_method="linear",
),
)
)
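Note: the distributed settings above imply an 8-GPU layout (2-way data parallel x 4-stage pipeline), with `pipeline_num_layers=cfg.hidden_layers` spreading the 32 decoder layers across the stages (8 per stage). A back-of-the-envelope check using the usual LibAI/Megatron-style batch semantics (a sketch, not trainer output):

data_parallel_size = 2
tensor_parallel_size = 1
pipeline_parallel_size = 4
train_micro_batch_size = 2
num_accumulation_steps = 8

total_gpus = data_parallel_size * tensor_parallel_size * pipeline_parallel_size
global_batch_size = train_micro_batch_size * num_accumulation_steps * data_parallel_size

print(total_gpus)         # 8 devices for this layout
print(global_batch_size)  # 32 sequences per optimizer step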
46 changes: 46 additions & 0 deletions projects/Llama/dataset.py
@@ -0,0 +1,46 @@
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random

import oneflow as flow
from oneflow.utils.data import Dataset

from libai.data.structures import DistTensorData, Instance


def pad_right(data, pad_id=0, max_len=1350):
n = max_len - data.shape[0]
return flow.cat((data, flow.full((n,), pad_id, dtype=data.dtype)))


class AlpacaDataset(Dataset):
def __init__(self, path, tokenizer, max_len=1350):
self.data = flow.load(path)
random.shuffle(self.data)
self.tokenizer = tokenizer
self.max_len = max_len

def __len__(self):
return len(self.data)

def __getitem__(self, index):
input_ids = pad_right(self.data[index]["input_ids"], pad_id=0, max_len=self.max_len)
labels = pad_right(self.data[index]["labels"], pad_id=-1, max_len=self.max_len)

return Instance(
input_ids=DistTensorData(input_ids),
labels=DistTensorData(labels),
)
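Note: every sample is right-padded to `max_len`; `input_ids` are padded with the pad id 0, while `labels` are padded with -1, presumably so the loss can skip those positions. A quick check of `pad_right` with made-up token ids:

import oneflow as flow

ids = flow.tensor([101, 2009, 2003], dtype=flow.long)
padded_ids = pad_right(ids, pad_id=0, max_len=6)      # tensor([101, 2009, 2003, 0, 0, 0])
padded_labels = pad_right(ids, pad_id=-1, max_len=6)  # trailing positions become -1
print(padded_ids.shape)                               # oneflow.Size([6])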
