Commit
Merge pull request #133 from ssbuild/dev
support seed
ssbuild authored Nov 20, 2023
2 parents 6918319 + 87019d0 commit d93761a
Showing 27 changed files with 193 additions and 166 deletions.
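Every handler below follows the same migration: the two-step `GenerateProcess` setup (construct, then `preprocess(kwargs)`) is collapsed into a single `GenArgs` constructor that takes the request kwargs up front, `postprocess` is renamed `build_args`, and each `embedding` method now also constructs a `GenArgs` so the new seed handling applies there too. In outline (a sketch of the call-site change only, not a complete handler):

```python
# before: two-step setup; request kwargs are preprocessed after construction
args_process = GenerateProcess(self, is_stream=True)
args_process.preprocess(kwargs)           # pops nchar/gtype, builds ChunkData
# ...
args_process.postprocess(default_kwargs)  # resolves stop words etc.

# after: kwargs go to the constructor, which now also pops 'seed'
args_process = GenArgs(kwargs, self, is_stream=True)
# ...
args_process.build_args(default_kwargs)   # same role as the old postprocess
```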
1 change: 1 addition & 0 deletions README.MD
@@ -10,6 +10,7 @@
 
 ## update information
 ```text
+11-20 support seed for generator sample
 11-06 fix pydantic 2 and support api_keys in config
 11-04 support yi aigc-zoo>=0.2.7.post2, support pydantic >= 2
 11-01 support bluelm aigc-zoo>=0.2.7.post1
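Because `chat` and `chat_stream` forward their `**kwargs` into `GenArgs`, a caller can pin sampling by passing `seed` alongside the usual generation options. A hedged usage sketch (engine construction is elided, and the result field name is an assumption, not shown in this diff):

```python
# `engine` stands for any EngineAPI_Base handler from serving/model_handler/*/infer.py.
messages = [{"role": "user", "content": "Write a haiku about autumn."}]

# Same seed and sampling settings -> identical draws for do_sample runs.
first = engine.chat(messages, do_sample=True, top_p=0.8, seed=1234)
second = engine.chat(messages, do_sample=True, top_p=0.8, seed=1234)
assert first.result["response"] == second.result["response"]  # field name assumed
```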
13 changes: 6 additions & 7 deletions serving/model_handler/baichuan2_13b/infer.py
@@ -13,7 +13,7 @@
 from transformers import HfArgumentParser, BitsAndBytesConfig, GenerationConfig
 from aigc_zoo.model_zoo.baichuan.baichuan2_13b.llm_model import MyTransformer,BaichuanConfig,BaichuanTokenizer,\
     MyBaichuanForCausalLM,PetlArguments,PetlModel
-from serving.model_handler.base import EngineAPI_Base,CompletionResult, CompletionResult,LoraModelState, load_lora_config, GenerateProcess,WorkMode
+from serving.model_handler.base import EngineAPI_Base,CompletionResult, CompletionResult,LoraModelState, load_lora_config, GenArgs,WorkMode
 from serving.prompt import *


@@ -123,12 +123,11 @@ def get_default_gen_args(self):
         return default_kwargs
 
     def chat_stream(self,messages: List[Dict], **kwargs):
-        args_process = GenerateProcess(self,is_stream=True)
-        args_process.preprocess(kwargs)
+        args_process = GenArgs(kwargs, self, is_stream=True)
         chunk = args_process.chunk
         default_kwargs=self.get_default_gen_args()
         default_kwargs.update(kwargs)
-        args_process.postprocess(default_kwargs)
+        args_process.build_args(default_kwargs)
         stopping_criteria = default_kwargs.pop('stopping_criteria', None)
         generation_config = GenerationConfig(**default_kwargs)
         for response in self.get_model().chat(tokenizer=self.tokenizer,
@@ -156,11 +155,10 @@ def chat_stream(self,messages: List[Dict], **kwargs):
 
 
     def chat(self,messages: List[Dict], **kwargs):
-        args_process = GenerateProcess(self)
-        args_process.preprocess(kwargs)
+        args_process = GenArgs(kwargs, self)
         default_kwargs = self.get_default_gen_args()
         default_kwargs.update(kwargs)
-        args_process.postprocess(default_kwargs)
+        args_process.build_args(default_kwargs)
         stopping_criteria = default_kwargs.pop('stopping_criteria', None)
         generation_config = GenerationConfig(**default_kwargs)
         response = self.get_model().chat(tokenizer=self.tokenizer,
@@ -177,6 +175,7 @@ def chat(self,messages: List[Dict], **kwargs):
 
 
     def embedding(self, query, **kwargs):
+        args_process = GenArgs(kwargs, self)
         model = self.get_model()
         inputs = self.tokenizer(query, return_tensors="pt")
         inputs = inputs.to(model.device)
13 changes: 6 additions & 7 deletions serving/model_handler/baichuan2_7b/infer.py
@@ -14,7 +14,7 @@
 from aigc_zoo.model_zoo.baichuan.baichuan2_7b.llm_model import MyTransformer,BaichuanConfig,BaichuanTokenizer,\
     MyBaichuanForCausalLM,PetlArguments,PetlModel
 from serving.model_handler.base import EngineAPI_Base,CompletionResult,LoraModelState, \
-    load_lora_config,GenerateProcess,WorkMode
+    load_lora_config,GenArgs,WorkMode
 from serving.prompt import *

 class NN_DataHelper(DataHelper):pass
@@ -125,12 +125,11 @@ def get_default_gen_args(self):
         return default_kwargs
 
     def chat_stream(self,messages: List[Dict], **kwargs):
-        args_process = GenerateProcess(self,is_stream=True)
-        args_process.preprocess(kwargs)
+        args_process = GenArgs(kwargs, self, is_stream=True)
         chunk = args_process.chunk
         default_kwargs=self.get_default_gen_args()
         default_kwargs.update(kwargs)
-        args_process.postprocess(default_kwargs)
+        args_process.build_args(default_kwargs)
         stopping_criteria = default_kwargs.pop('stopping_criteria', None)
         generation_config = GenerationConfig(**default_kwargs)
         for response in self.get_model().chat(tokenizer=self.tokenizer,
@@ -158,11 +157,10 @@ def chat_stream(self,messages: List[Dict], **kwargs):
 
 
     def chat(self,messages: List[Dict], **kwargs):
-        args_process = GenerateProcess(self)
-        args_process.preprocess(kwargs)
+        args_process = GenArgs(kwargs, self)
         default_kwargs = self.get_default_gen_args()
         default_kwargs.update(kwargs)
-        args_process.postprocess(default_kwargs)
+        args_process.build_args(default_kwargs)
         stopping_criteria = default_kwargs.pop('stopping_criteria', None)
         generation_config = GenerationConfig(**default_kwargs)
         response = self.get_model().chat(tokenizer=self.tokenizer,
@@ -177,6 +175,7 @@ def chat(self,messages: List[Dict], **kwargs):
 
 
     def embedding(self, query, **kwargs):
+        args_process = GenArgs(kwargs, self)
         model = self.get_model()
         inputs = self.tokenizer(query, return_tensors="pt")
         inputs = inputs.to(model.device)
17 changes: 8 additions & 9 deletions serving/model_handler/baichuan_13b/infer.py
@@ -13,7 +13,7 @@
 from transformers import HfArgumentParser, BitsAndBytesConfig, GenerationConfig
 from aigc_zoo.model_zoo.baichuan.baichuan_13b.llm_model import MyTransformer,BaichuanConfig,BaichuanTokenizer,\
     MyBaichuanForCausalLM,PetlArguments,PetlModel
-from serving.model_handler.base import EngineAPI_Base,CompletionResult,LoraModelState, load_lora_config, GenerateProcess,WorkMode
+from serving.model_handler.base import EngineAPI_Base,CompletionResult,LoraModelState, load_lora_config, GenArgs,WorkMode
 from serving.prompt import *


@@ -123,12 +123,11 @@ def get_default_gen_args(self):
         return default_kwargs
 
     def chat_stream(self,messages: List[Dict], **kwargs):
-        args_process = GenerateProcess(self,is_stream=True)
-        args_process.preprocess(kwargs)
+        args_process = GenArgs(kwargs, self, is_stream=True)
         chunk = args_process.chunk
         default_kwargs = self.get_default_gen_args()
         default_kwargs.update(kwargs)
-        args_process.postprocess(default_kwargs)
+        args_process.build_args(default_kwargs)
         stopping_criteria = default_kwargs.pop('stopping_criteria', None)
         generation_config = GenerationConfig(**default_kwargs)
         for response in self.get_model().chat(tokenizer=self.tokenizer,
@@ -155,11 +154,10 @@ def chat_stream(self,messages: List[Dict], **kwargs):
 
 
     def chat(self,messages: List[Dict], **kwargs):
-        args_process = GenerateProcess(self)
-        args_process.preprocess(kwargs)
+        args_process = GenArgs(kwargs, self)
         default_kwargs = self.get_default_gen_args()
         default_kwargs.update(kwargs)
-        args_process.postprocess(default_kwargs)
+        args_process.build_args(default_kwargs)
         stopping_criteria = default_kwargs.pop('stopping_criteria', None)
         generation_config = GenerationConfig(**default_kwargs)
         response = self.get_model().chat(tokenizer=self.tokenizer,
@@ -174,10 +172,10 @@ def chat(self,messages: List[Dict], **kwargs):
         })
 
     def generate(self, messages: List[Dict], **kwargs):
-        args_process = GenerateProcess(self)
+        args_process = GenArgs(kwargs, self)
         default_kwargs = self.get_default_gen_args()
         default_kwargs.update(kwargs)
-        args_process.postprocess(default_kwargs)
+        args_process.build_args(default_kwargs)
         query = args_process.get_chat_info(messages, chat_format="generate")
         response = self.get_model().generate(query=query, **kwargs)
         return CompletionResult(result={
@@ -187,6 +185,7 @@ def generate(self, messages: List[Dict], **kwargs):
 
 
     def embedding(self, query, **kwargs):
+        args_process = GenArgs(kwargs, self)
         model = self.get_model()
         inputs = self.tokenizer(query, return_tensors="pt")
         inputs = inputs.to(model.device)
17 changes: 8 additions & 9 deletions serving/model_handler/baichuan_7b/infer.py
@@ -14,7 +14,7 @@
 from transformers import HfArgumentParser, BitsAndBytesConfig, GenerationConfig
 from aigc_zoo.model_zoo.baichuan.baichuan_7b.llm_model import MyTransformer,BaiChuanConfig,BaiChuanTokenizer,PetlArguments,PetlModel
 from aigc_zoo.generator_utils.generator_llm import Generate
-from serving.model_handler.base import EngineAPI_Base,CompletionResult, LoraModelState, load_lora_config,GenerateProcess,WorkMode
+from serving.model_handler.base import EngineAPI_Base,CompletionResult, LoraModelState, load_lora_config,GenArgs,WorkMode
 from serving.prompt import *

 class NN_DataHelper(DataHelper):pass
@@ -134,12 +134,11 @@ def get_default_gen_args(self):
         return default_kwargs
 
     def chat_stream(self,messages: List[Dict], **kwargs):
-        args_process = GenerateProcess(self,is_stream=True)
-        args_process.preprocess(kwargs)
+        args_process = GenArgs(kwargs, self, is_stream=True)
         chunk = args_process.chunk
         default_kwargs= self.get_default_gen_args()
         default_kwargs.update(kwargs)
-        generation_config = GenerationConfig(**args_process.postprocess(default_kwargs))
+        generation_config = GenerationConfig(**args_process.build_args(default_kwargs))
         query, history = args_process.get_chat_info(messages)
         prompt = get_chat_default(self.tokenizer, query, history)

@@ -177,13 +176,12 @@ def stream_generator():


     def chat(self,messages: List[Dict], **kwargs):
-        args_process = GenerateProcess(self)
-        args_process.preprocess(kwargs)
+        args_process = GenArgs(kwargs, self)
         default_kwargs = self.get_default_gen_args()
         default_kwargs.update(kwargs)
         query, history = args_process.get_chat_info(messages)
         prompt = get_chat_default(self.tokenizer, query, history)
-        response = self.gen_core.generate(query=prompt, **args_process.postprocess(default_kwargs))
+        response = self.gen_core.generate(query=prompt, **args_process.build_args(default_kwargs))
         response = args_process.postprocess_response(response, **kwargs)
         # history = history + [(query, response)]
         return CompletionResult(result={
@@ -192,10 +190,10 @@ def chat(self,messages: List[Dict], **kwargs):
         })
 
     def generate(self,messages: List[Dict],**kwargs):
-        args_process = GenerateProcess(self)
+        args_process = GenArgs(kwargs, self)
         default_kwargs = self.get_default_gen_args()
         default_kwargs.update(kwargs)
-        args_process.postprocess(default_kwargs)
+        args_process.build_args(default_kwargs)
         query = args_process.get_chat_info(messages,chat_format="generate")
         response = self.gen_core.generate(query=query, **kwargs)
         return CompletionResult(result={
@@ -204,6 +202,7 @@ def generate(self,messages: List[Dict],**kwargs):
         })
 
     def embedding(self, query, **kwargs):
+        args_process = GenArgs(kwargs, self)
         model = self.get_model()
         inputs = self.tokenizer(query, return_tensors="pt")
         inputs = inputs.to(model.device)
2 changes: 1 addition & 1 deletion serving/model_handler/base/__init__.py
@@ -4,6 +4,6 @@
 
 from .infer import EngineAPI_Base,flat_input,CompletionResult
 from .data_define import ChunkData, LoraModelState, WorkMode
-from .data_process import GenerateProcess
+from .data_process import GenArgs
 from .loaders import load_lora_config
 from .utils import is_quantization_bnb
28 changes: 24 additions & 4 deletions serving/model_handler/base/data_process.py
@@ -156,22 +156,42 @@ def _calc_stopped_samples(self, input_ids: torch.LongTensor) -> bool:
         del ids
         return False
 
-class GenerateProcess:
-    def __init__(self,this_obj,is_stream=False):
+class GenArgs:
+    def __init__(self,args_dict:Dict, this_obj,is_stream=False):
+        if args_dict is None:
+            args_dict = {}
+
         self.tokenizer: Optional[PreTrainedTokenizer] = this_obj.tokenizer
         self.config: Optional[PretrainedConfig] = this_obj.config
         self.is_stream = is_stream
         self.chunk: Optional[ChunkData] = None
         self.this_obj = this_obj
 
-    def preprocess(self, args_dict: dict):
+        # support seed
+        self.multinomial_fn = torch.multinomial
+        self.__preprocess(args_dict)
+
+    def __del__(self):
+        # restore
+        if torch.multinomial != self.multinomial_fn:
+            torch.multinomial = self.multinomial_fn
+
+    def __preprocess(self, args_dict):
         if self.is_stream:
             nchar = args_dict.pop('nchar',1)
             gtype = args_dict.pop('gtype',"total")
             self.chunk = ChunkData(nchar=nchar, stop=args_dict.get('stop', None), mode=gtype)
 
+        seed = args_dict.pop('seed',None)
+
+        # process isolation: different requests don't affect each other
+        if isinstance(seed,int):
+            device = self.this_obj.get_model().device
+            torch.multinomial = lambda *args, **kwargs: self.multinomial_fn(*args,
+                generator=torch.Generator(device=device).manual_seed(seed),
+                **kwargs)
         return args_dict
 
-    def postprocess(self, args_dict):
+    def build_args(self, args_dict):
         stop = args_dict.pop('stop',None)
         if stop is None:
             return args_dict
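The seed support above works by temporarily swapping `torch.multinomial` for a wrapper that forces a freshly seeded `torch.Generator`, restoring the original in `__del__`. A minimal standalone sketch of the same technique (the helper names are illustrative, not part of this repository):

```python
import torch

_original_multinomial = torch.multinomial

def patch_multinomial(seed: int, device: str = "cpu") -> None:
    """Route every torch.multinomial call through a freshly seeded generator."""
    def seeded_multinomial(*args, **kwargs):
        gen = torch.Generator(device=device).manual_seed(seed)
        return _original_multinomial(*args, generator=gen, **kwargs)
    torch.multinomial = seeded_multinomial

def restore_multinomial() -> None:
    torch.multinomial = _original_multinomial

probs = torch.tensor([0.1, 0.2, 0.3, 0.4])
patch_multinomial(seed=42)
a = torch.multinomial(probs, num_samples=2)
b = torch.multinomial(probs, num_samples=2)
restore_multinomial()
assert torch.equal(a, b)  # a fresh generator per call restarts the stream
```

As in `GenArgs.__preprocess`, the generator is re-created inside the wrapper, so every sampling call during a request draws from the same restarted stream; note that relying on `__del__` for restoration ties cleanup to garbage collection rather than to request scope.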
17 changes: 8 additions & 9 deletions serving/model_handler/bluelm/infer.py
@@ -18,7 +18,7 @@
 from aigc_zoo.model_zoo.bluelm.llm_model import MyBlueLMForCausalLM,BlueLMTokenizer,BlueLMConfig
 from aigc_zoo.model_zoo.llm.llm_model import MyTransformer,PetlArguments,PetlModel,AutoConfig
 from aigc_zoo.generator_utils.generator_llm import Generate
-from serving.model_handler.base import EngineAPI_Base,CompletionResult, LoraModelState, load_lora_config, GenerateProcess,WorkMode,ChunkData
+from serving.model_handler.base import EngineAPI_Base,CompletionResult, LoraModelState, load_lora_config, GenArgs,WorkMode,ChunkData
 from serving.prompt import *

 class NN_DataHelper(DataHelper):pass
@@ -147,11 +147,10 @@ def get_default_gen_args(self):
         return default_kwargs
 
     def chat_stream(self,messages: List[Dict], **kwargs):
-        args_process = GenerateProcess(self,is_stream=True)
-        args_process.preprocess(kwargs)
+        args_process = GenArgs(kwargs, self, is_stream=True)
         default_kwargs = self.get_default_gen_args()
         default_kwargs.update(kwargs)
-        args_process.postprocess(default_kwargs)
+        args_process.build_args(default_kwargs)
         prefix,query, history = args_process.get_chat_info_with_system(messages)
         prompt = get_chat_bluelm(self.tokenizer, query, history=history, prefix=prefix)
         skip_word_list = default_kwargs.get('eos_token_id', None) or [self.tokenizer.eos_token_id]
@@ -164,11 +163,10 @@ def chat_stream(self,messages: List[Dict], **kwargs):
 
 
     def chat(self,messages: List[Dict], **kwargs):
-        args_process = GenerateProcess(self)
-        args_process.preprocess(kwargs)
+        args_process = GenArgs(kwargs, self)
         default_kwargs = self.get_default_gen_args()
         default_kwargs.update(kwargs)
-        args_process.postprocess(default_kwargs)
+        args_process.build_args(default_kwargs)
         prefix,query, history = args_process.get_chat_info_with_system(messages)
         prompt = get_chat_bluelm(self.tokenizer, query, history=history, prefix=prefix)
         response = self.gen_core.generate(query=prompt, **default_kwargs)
@@ -180,10 +178,10 @@ def chat(self,messages: List[Dict], **kwargs):
 
 
     def generate(self,messages: List[Dict],**kwargs):
-        args_process = GenerateProcess(self)
+        args_process = GenArgs(kwargs, self)
         default_kwargs = self.get_default_gen_args()
         default_kwargs.update(kwargs)
-        args_process.postprocess(default_kwargs)
+        args_process.build_args(default_kwargs)
         query = args_process.get_chat_info(messages,chat_format="generate")
         response = self.gen_core.generate(query=query, **default_kwargs)
         return CompletionResult(result={
@@ -192,6 +190,7 @@ def generate(self,messages: List[Dict],**kwargs):
         })
 
     def embedding(self, query, **kwargs):
+        args_process = GenArgs(kwargs, self)
         model = self.get_model()
         inputs = self.tokenizer(query, return_tensors="pt")
         inputs = inputs.to(model.device)
17 changes: 8 additions & 9 deletions serving/model_handler/chatglm/infer.py
@@ -13,7 +13,7 @@
 from transformers import HfArgumentParser
 from aigc_zoo.model_zoo.chatglm.llm_model import MyTransformer, ChatGLMTokenizer, PetlArguments, setup_model_profile, \
     ChatGLMConfig,PetlModel
-from serving.model_handler.base import EngineAPI_Base,CompletionResult,LoraModelState, load_lora_config, GenerateProcess,WorkMode
+from serving.model_handler.base import EngineAPI_Base,CompletionResult,LoraModelState, load_lora_config, GenArgs,WorkMode
 from serving.prompt import *

 class NN_DataHelper(DataHelper):pass
@@ -148,12 +148,11 @@ def get_default_gen_args(self):
         return default_kwargs
 
     def chat_stream(self, messages: List[Dict], **kwargs):
-        args_process = GenerateProcess(self,is_stream=True)
-        args_process.preprocess(kwargs)
+        args_process = GenArgs(kwargs, self, is_stream=True)
         chunk = args_process.chunk
         default_kwargs=self.get_default_gen_args()
         default_kwargs.update(kwargs)
-        args_process.postprocess(default_kwargs)
+        args_process.build_args(default_kwargs)
         query,history = args_process.get_chat_info(messages)
         for response, history in self.model.stream_chat(self.tokenizer, query=query,history=history, **kwargs):
             chunk.step(response)
@@ -175,11 +174,10 @@ def chat_stream(self, messages: List[Dict], **kwargs):
             }, complete=False)
 
     def chat(self,messages: List[Dict], **kwargs):
-        args_process = GenerateProcess(self)
-        args_process.preprocess(kwargs)
+        args_process = GenArgs(kwargs, self)
         default_kwargs=self.get_default_gen_args()
         default_kwargs.update(kwargs)
-        args_process.postprocess(default_kwargs)
+        args_process.build_args(default_kwargs)
         query, history = args_process.get_chat_info(messages)
         response, history = self.model.chat(self.tokenizer, query=query,history=history, **default_kwargs)
         response = args_process.postprocess_response(response, **kwargs)
@@ -189,10 +187,10 @@ def chat(self,messages: List[Dict], **kwargs):
         })
 
     def generate(self,messages: List[Dict],**kwargs):
-        args_process = GenerateProcess(self)
+        args_process = GenArgs(kwargs, self)
         default_kwargs=self.get_default_gen_args()
         default_kwargs.update(kwargs)
-        args_process.postprocess(default_kwargs)
+        args_process.build_args(default_kwargs)
         query = args_process.get_chat_info(messages,chat_format="generate")
         output,_ = self.model.chat(self.tokenizer, query=query,**default_kwargs)
         output_scores = default_kwargs.get('output_scores', False)
@@ -205,6 +203,7 @@ def generate(self,messages: List[Dict],**kwargs):
         })
 
     def embedding(self, query, **kwargs):
+        args_process = GenArgs(kwargs, self)
         model = self.get_model()
         inputs = self.tokenizer(query, return_tensors="pt")
         inputs = inputs.to(model.device)