support eos_token list in turbomind (#3044)
* support list of eos_id

* fix wrong size

* fix lint

* fix lint

* fix ut

* support GenerationConfig.eos_token_id is None

* rename

* remove unused

* remove start_id

* move eos_id from model to gen_config

* move stop/bad words to turbomind::GenerationConfig

* remove tokenizer_info

* update invoke_min_length_penalty

* use array to store stop_ids/bad_ids
irexyc authored Feb 17, 2025
1 parent 3f2c74c commit bfc845a
Showing 27 changed files with 181 additions and 229 deletions.
12 changes: 12 additions & 0 deletions lmdeploy/messages.py
@@ -129,6 +129,18 @@ def special_word_token_ids(words):
self.stop_token_ids = list(set(stop_token_ids)) or None
self.bad_token_ids = list(set(bad_token_ids)) or None

def update_from_hf_gen_cfg(self, generation_config, tokenizer_eos_token_id):
"""update the stop_token_ids."""
stop_token_ids = self.stop_token_ids or []
if tokenizer_eos_token_id is not None:
stop_token_ids.append(tokenizer_eos_token_id)
eos_token_id = generation_config.get('eos_token_id')
if eos_token_id is not None:
eos_token_id = {eos_token_id} if isinstance(eos_token_id, int) else set(eos_token_id)
if stop_token_ids:
eos_token_id.update(stop_token_ids)
self.stop_token_ids = list(eos_token_id)

def __post_init__(self):
"""Check input validation."""
assert type(self.n) == int and self.n > 0, 'n is not a positive integer'
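In effect, the new update_from_hf_gen_cfg merges three sources of stop tokens: the ids already set on the config, the tokenizer's eos id, and whatever eos_token_id (int or list) the HF generation config declares. A minimal sketch with made-up token ids, assuming GenerationConfig accepts stop_token_ids as a constructor argument:

from lmdeploy.messages import GenerationConfig

gen_config = GenerationConfig(stop_token_ids=[42])   # example ids are invented
hf_gen_cfg = {'eos_token_id': [151645, 151643]}      # shape of get_hf_gen_cfg()'s output

gen_config.update_from_hf_gen_cfg(hf_gen_cfg, tokenizer_eos_token_id=151645)

# The result is the union of all three sources (ordering comes from a set).
assert sorted(gen_config.stop_token_ids) == [42, 151643, 151645]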
5 changes: 4 additions & 1 deletion lmdeploy/serve/async_engine.py
@@ -25,7 +25,7 @@
from lmdeploy.model import MODELS, ChatTemplateConfig, best_match_model
from lmdeploy.serve.utils import LogitsMixin
from lmdeploy.tokenizer import DetokenizeState
from lmdeploy.utils import _get_and_verify_max_len, _stop_words, get_logger
from lmdeploy.utils import _get_and_verify_max_len, _stop_words, get_hf_gen_cfg, get_logger

logger = get_logger('lmdeploy')

@@ -270,6 +270,8 @@ def __init__(self,
logger.info(f'updated chat_template_onfig={chat_template_config}')

self.tokenizer = Tokenizer(model_path)
self.hf_gen_cfg = get_hf_gen_cfg(model_path)

# build backend engine
if backend == 'turbomind':
self._build_turbomind(model_path=model_path, backend_config=backend_config, **kwargs)
@@ -634,6 +636,7 @@ async def generate(
else:
gen_config = deepcopy(gen_config)
gen_config.convert_stop_bad_words_to_ids(self.tokenizer)
gen_config.update_from_hf_gen_cfg(self.hf_gen_cfg, self.tokenizer.eos_token_id)
if gen_config.stop_token_ids is None:
gen_config.stop_token_ids = self.stop_words
if not gen_config.do_sample:
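From the caller's side, the practical effect is that models whose generation_config.json lists several eos_token_id values now stop correctly without passing stop_token_ids by hand. A rough usage sketch; the model name is only a placeholder for any such HF model:

from lmdeploy import GenerationConfig, pipeline

# Any HF model whose generation_config.json declares a list of eos ids behaves the same.
pipe = pipeline('Qwen/Qwen2.5-7B-Instruct')

# No stop_token_ids given: the engine merges the tokenizer eos id and every eos
# id from the HF generation config before dispatching to turbomind.
response = pipe('Hello, who are you?', gen_config=GenerationConfig(max_new_tokens=32))
print(response.text)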
2 changes: 0 additions & 2 deletions lmdeploy/turbomind/deploy/config.py
@@ -47,8 +47,6 @@ class ModelConfig:
inter_size: List[int] = None
norm_eps: float = None
attn_bias: int = 0
start_id: int = None
end_id: int = None
size_per_head: int = 128
group_size: int = 64
weight_type: str = None
5 changes: 0 additions & 5 deletions lmdeploy/turbomind/deploy/source_model/base.py
@@ -36,11 +36,6 @@ def __init__(self, model_path: str, tokenizer_path: str, **kwargs):
self.model_path = model_path
self.tokenizer_path = tokenizer_path

@abstractmethod
def tokenizer_info(self):
"""Read tokenizer info."""
pass

@abstractmethod
def model_info(self) -> Dict:
"""Read model info."""
6 changes: 0 additions & 6 deletions lmdeploy/turbomind/deploy/source_model/deepseek2.py
@@ -79,12 +79,6 @@ class DeepSeek2Model(LlamaModel):

Reader = DeepSeek2Reader

def tokenizer_info(self):
n_words = self.model_config['vocab_size']
bos_id = self.model_config['bos_token_id']
eos_id = self.model_config['eos_token_id']
return n_words, bos_id, eos_id

def model_info(self):
cfg = self.model_config
info = super().model_info()
10 changes: 0 additions & 10 deletions lmdeploy/turbomind/deploy/source_model/glm4.py
@@ -1,7 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import json
import os.path as osp
from typing import List

import torch

@@ -66,15 +65,6 @@ def __init__(self, model_path: str, tokenizer_path: str, **kwargs):
with open(config_path) as f:
self.config = json.load(f)

def tokenizer_info(self):
"""Read tokenizer info."""
n_words = self.config['padded_vocab_size']
bos_id = 0
eos_id = self.config['eos_token_id']
if isinstance(eos_id, List):
eos_id = eos_id[0]
return n_words, bos_id, eos_id

def model_info(self):
"""Read model info."""
config = self.config
12 changes: 0 additions & 12 deletions lmdeploy/turbomind/deploy/source_model/llama.py
@@ -7,7 +7,6 @@
import torch

from lmdeploy.archs import get_model_arch
from lmdeploy.tokenizer import Tokenizer

from ..loader import create_loader
from .base import INPUT_MODELS, BaseInputModel, BaseReader
@@ -115,17 +114,6 @@ def readers(self):
reader = self.Reader(param, {}, False, self.model_config, policy=self.policy)
yield i, reader

def tokenizer_info(self):
"""Read tokenizer info."""
assert osp.isdir(self.model_path), self.model_path
tk_model = Tokenizer(self.model_path)
n_words = tk_model.vocab_size
bos_id = tk_model.bos_token_id
eos_id = tk_model.eos_token_id
# bos_id may be None
bos_id = bos_id or 0
return n_words, bos_id, eos_id

def model_info(self):
"""Read model info."""
params_path = osp.join(self.model_path, 'config.json')
7 changes: 0 additions & 7 deletions lmdeploy/turbomind/deploy/source_model/molmo.py
@@ -83,13 +83,6 @@ def __init__(self, model_path: str, tokenizer_path: str, **kwargs):
with open(config_path) as f:
self.config = json.load(f)

def tokenizer_info(self):

n_words = 152064
bos_id = 151643
eos_id = 151643
return n_words, bos_id, eos_id

def model_info(self):
config = self.config
num_layer = config['num_hidden_layers']
25 changes: 0 additions & 25 deletions lmdeploy/turbomind/deploy/source_model/qwen.py
@@ -56,13 +56,6 @@ class QwenModel(LlamaModel):

Reader = QwenReader

def tokenizer_info(self):
"""Read tokenizer info."""
n_words = 151851
bos_id = 0
eos_id = 151643
return n_words, bos_id, eos_id

def model_info(self):
"""Read model info."""
params_path = osp.join(self.model_path, 'config.json')
@@ -105,16 +98,6 @@ class Qwen2Model(LlamaModel):

Reader = LlamaReader

def tokenizer_info(self):
"""set tokenizer info.
Refer to https://huggingface.co/Qwen/Qwen1.5-7B-Chat/blob/main/generation_config.json
""" # noqa: E501
n_words = 152064
bos_id = 151643
eos_id = 151645
return n_words, bos_id, eos_id

def model_info(self):
cfg = super().model_info()
cfg['attn_bias'] = 1
@@ -159,14 +142,6 @@ class Qwen2MoeModel(LlamaModel):

Reader = Qwen2MoeReader

def tokenizer_info(self):
"""https://huggingface.co/Qwen/Qwen1.5-7B-Chat/blob/main/generation_con
fig.json.""" # noqa: E501
n_words = 152064
bos_id = 151643
eos_id = 151645
return n_words, bos_id, eos_id

def model_info(self):
cfg = self.model_config
info = super().model_info()
9 changes: 2 additions & 7 deletions lmdeploy/turbomind/deploy/target_model/base.py
@@ -55,11 +55,9 @@ def __init__(self, input_model: BaseInputModel, cfg: TurbomindModelConfig, model
self.to_file = True if out_dir else False
self.tm_params = {}

# get `model_info` and `tokenizer_info` at first, which
# will be updated to `self.model_config` and `self.attention_config`
# get `model_info` at first, which will be updated to `self.model_config` and `self.attention_config`
self.input_model_info = self.input_model.model_info()
self.input_model_info = self.single_to_list(self.input_model_info, keys=['inter_size', 'expert_num'])
self.input_model_tokenizer_info = self.input_model.tokenizer_info()
self.permute_qk = self.input_model_info.get('permute_qk', True)
self.update_model_config()
for i, v in enumerate(self.model_config.inter_size):
@@ -97,11 +95,8 @@ def single_to_list(self, config: dict, keys):

def update_model_config(self):
"""Update `self.model_config` according to the input_model's
`tokenizer_info` and `model_info`"""
_, bos_id, eos_id = self.input_model_tokenizer_info

`model_info`"""
final_cfg = config_to_dict(self.model_config)
final_cfg.update(dict(start_id=bos_id, end_id=eos_id))
final_cfg.update(self.input_model_info)
if 'embedding_size' not in self.input_model_info.keys():
final_cfg.update(embedding_size=self.input_model_info['vocab_size'])
30 changes: 8 additions & 22 deletions lmdeploy/turbomind/turbomind.py
@@ -39,8 +39,8 @@
def _construct_stop_or_bad_words(words: List[int] = None):
if words is None or len(words) == 0:
return None
offsets = range(1, len(words) + 1)
combined = np.array([[words, offsets]]).astype(np.int32)
offsets = list(range(1, len(words) + 1))
combined = [words, offsets]
return combined


@@ -119,7 +119,6 @@ def __init__(self,
pass

self.session_len = self.config.session_len
self.eos_id = self.tokenizer.eos_token_id

def _create_weight(self, model_comm):
"""Allocate weight buffer, load params if from_workspace."""
@@ -404,7 +403,6 @@ def __init__(self, tm_model: TurboMind, config: TurbomindModelConfig, cuda_strea
self.node_id = tm_model.node_id
self.gpu_count = tm_model.gpu_count

self.eos_id = tm_model.eos_id
self.session_len = tm_model.session_len

self.nccl_params = tm_model.nccl_params
@@ -498,24 +496,6 @@ def prepare_inputs(self,
inputs['input_embeddings'] = input_embeddings
inputs['input_embedding_ranges'] = input_embedding_ranges

bad_words = []
if gen_config.bad_token_ids is not None:
bad_words.extend(gen_config.bad_token_ids)
if gen_config.ignore_eos:
stop_words = None
bad_words.append(self.eos_id)
else:
stop_words = gen_config.stop_token_ids or []
if self.eos_id not in stop_words:
stop_words.append(self.eos_id)
stop_words = _construct_stop_or_bad_words(stop_words)
bad_words = _construct_stop_or_bad_words(bad_words)

if stop_words is not None:
inputs['stop_words_list'] = stop_words
if bad_words is not None:
inputs['bad_words_list'] = bad_words

return inputs, input_len

async def async_cancel(self, session_id: int = None):
@@ -647,6 +627,12 @@ def _get_generation_config(self, cfg: GenerationConfig):
c.top_p = cfg.top_p
c.min_p = cfg.min_p
c.temperature = cfg.temperature
if cfg.stop_token_ids:
c.eos_ids = cfg.stop_token_ids
if cfg.bad_token_ids:
c.bad_ids = _construct_stop_or_bad_words(cfg.bad_token_ids)
if not cfg.ignore_eos and cfg.stop_token_ids:
c.stop_ids = _construct_stop_or_bad_words(cfg.stop_token_ids)
c.repetition_penalty = cfg.repetition_penalty
if cfg.min_new_tokens:
c.min_new_tokens = cfg.min_new_tokens
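For reference, the reworked helper now returns a plain pair of Python lists rather than a numpy array. A small self-contained sketch with invented ids, showing the (token_ids, offsets) shape that _get_generation_config forwards to the C++ side:

# Mirrors the reworked helper above, copied here so the sketch runs standalone.
def construct_stop_or_bad_words(words=None):
    if words is None or len(words) == 0:
        return None
    offsets = list(range(1, len(words) + 1))   # 1-based end offsets
    return [words, offsets]

# Example ids are invented.
print(construct_stop_or_bad_words([151645, 151643]))
# -> [[151645, 151643], [1, 2]]
# The two lists line up with the std::array<std::vector<int>, 2> stop_ids /
# bad_ids fields added to turbomind::GenerationConfig in request.h below.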
9 changes: 9 additions & 0 deletions lmdeploy/utils.py
@@ -180,6 +180,15 @@ def _stop_words(stop_words: List[Union[int, str]], tokenizer: object):
return stop_words


def get_hf_gen_cfg(path: str):
from transformers import GenerationConfig
try:
cfg = GenerationConfig.from_pretrained(path, trust_remote_code=True)
return cfg.to_dict()
except OSError:
return {}


def get_model(pretrained_model_name_or_path: str, download_dir: str = None, revision: str = None, token: str = None):
"""Get model from huggingface, modelscope or openmind_hub."""
import os
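A minimal usage sketch for the new helper; the path is a placeholder, and the helper falls back to an empty dict when no generation_config.json can be loaded:

from lmdeploy.utils import get_hf_gen_cfg

cfg = get_hf_gen_cfg('/path/to/hf/model')   # placeholder path
eos = cfg.get('eos_token_id')               # int, list of ints, or None
print(eos)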
22 changes: 22 additions & 0 deletions src/turbomind/engine/request.h
@@ -2,9 +2,11 @@

#pragma once

#include <array>
#include <atomic>
#include <cstdint>
#include <functional>
#include <iterator>
#include <memory>
#include <ostream>

@@ -16,6 +18,11 @@ struct GenerationConfig {
int max_new_tokens = 0;
int min_new_tokens = 0;

std::vector<int> eos_ids; // only support single token id

std::array<std::vector<int>, 2> stop_ids; // (token_id, offset)
std::array<std::vector<int>, 2> bad_ids;

int top_k = 1;
float top_p = 0.f;
float min_p = 0.f;
@@ -37,11 +44,26 @@
int output_logits = 0;
};

template<typename T>
inline std::ostream& operator<<(std::ostream& os, const std::vector<T>& vec)
{
os << "[";
std::copy(vec.begin(), vec.end(), std::ostream_iterator<T>(os, ", "));
if (!vec.empty()) {
os.seekp(-2, std::ios_base::end);
}
os << "]";
return os;
}

inline std::ostream& operator<<(std::ostream& os, const GenerationConfig& c)
{
os << "GenerationConfig { ";
os << "max_new_tokens=" << c.max_new_tokens;
os << ", min_new_tokens=" << c.min_new_tokens;
os << ", eos_ids=" << c.eos_ids;
os << ", stop_ids=[" << c.stop_ids[0] << ", " << c.stop_ids[1] << "]";
os << ", bad_ids=[" << c.bad_ids[0] << ", " << c.bad_ids[1] << "]";
os << ", top_p=" << c.top_p;
os << ", top_k=" << c.top_k;
os << ", min_p=" << c.min_p;