From d060bdb2ba9f6e5536a72090910ac8a32b8bc3ca Mon Sep 17 00:00:00 2001 From: moriyama naoto Date: Fri, 13 Aug 2021 13:22:46 +0900 Subject: [PATCH 1/2] add ct2 server parameters --- onmt/translate/translation_server.py | 25 ++++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/onmt/translate/translation_server.py b/onmt/translate/translation_server.py index 597a8e04ef..660654037e 100644 --- a/onmt/translate/translation_server.py +++ b/onmt/translate/translation_server.py @@ -81,15 +81,16 @@ class CTranslate2Translator(object): """ def __init__(self, model_path, device, device_index, batch_size, - beam_size, n_best, target_prefix=False, preload=False): + beam_size, n_best, target_prefix=False, preload=False, + inter_threads=1, intra_threads=1, compute_type="default"): import ctranslate2 self.translator = ctranslate2.Translator( model_path, device=device, device_index=device_index, - inter_threads=1, - intra_threads=1, - compute_type="default") + inter_threads=inter_threads, + intra_threads=intra_threads, + compute_type=compute_type) self.batch_size = batch_size self.beam_size = beam_size self.n_best = n_best @@ -155,7 +156,10 @@ def start(self, config_file): 'custom_opt': conf.get('custom_opt', None), 'on_timeout': conf.get('on_timeout', None), 'model_root': conf.get('model_root', self.models_root), - 'ct2_model': conf.get('ct2_model', None) + 'ct2_model': conf.get('ct2_model', None), + 'inter_threads': conf.get('inter_threads', None), + 'intra_threads': conf.get('intra_threads', None), + 'compute_type': conf.get('compute_type', None) } kwargs = {k: v for (k, v) in kwargs.items() if v is not None} model_id = conf.get("id", None) @@ -262,7 +266,8 @@ class ServerModel(object): def __init__(self, opt, model_id, preprocess_opt=None, tokenizer_opt=None, postprocess_opt=None, custom_opt=None, load=False, timeout=-1, - on_timeout="to_cpu", model_root="./", ct2_model=None): + on_timeout="to_cpu", model_root="./", ct2_model=None, + 
inter_threads=1, intra_threads=1, compute_type="default"): self.model_root = model_root self.opt = self.parse_opt(opt) self.custom_opt = custom_opt @@ -276,6 +281,9 @@ def __init__(self, opt, model_id, preprocess_opt=None, tokenizer_opt=None, self.ct2_model = os.path.join(model_root, ct2_model) \ if ct2_model is not None else None + self.inter_threads = inter_threads + self.intra_threads = intra_threads + self.compute_type = compute_type self.unload_timer = None self.user_opt = opt @@ -393,7 +401,10 @@ def load(self, preload=False): beam_size=self.opt.beam_size, n_best=self.opt.n_best, target_prefix=self.opt.tgt_prefix, - preload=preload) + preload=preload, + inter_threads=self.inter_threads, + intra_threads=self.intra_threads, + compute_type=self.compute_type) else: self.translator = build_translator( self.opt, report_score=False, From 8fb00ed721212736f46a045c5e0cc11cd5fac5d5 Mon Sep 17 00:00:00 2001 From: moriyama naoto Date: Fri, 13 Aug 2021 13:45:58 +0900 Subject: [PATCH 2/2] add doc --- onmt/translate/translation_server.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/onmt/translate/translation_server.py b/onmt/translate/translation_server.py index 660654037e..70aad3fbb1 100644 --- a/onmt/translate/translation_server.py +++ b/onmt/translate/translation_server.py @@ -262,6 +262,10 @@ class ServerModel(object): timeout (see :func:`do_timeout()`.) model_root (str): Path to the model directory it must contain the model and tokenizer file + ct2_model (str): Path to the CTranslate2 model directory + inter_threads (int): Maximum number of CPU translations executed in parallel + intra_threads (int): Number of OpenMP threads used per translation + compute_type (str): The type used for computation """ def __init__(self, opt, model_id, preprocess_opt=None, tokenizer_opt=None,