diff --git a/onmt/translate/translation_server.py b/onmt/translate/translation_server.py index 597a8e04ef..70aad3fbb1 100644 --- a/onmt/translate/translation_server.py +++ b/onmt/translate/translation_server.py @@ -81,15 +81,16 @@ class CTranslate2Translator(object): """ def __init__(self, model_path, device, device_index, batch_size, - beam_size, n_best, target_prefix=False, preload=False): + beam_size, n_best, target_prefix=False, preload=False, + inter_threads=1, intra_threads=1, compute_type="default"): import ctranslate2 self.translator = ctranslate2.Translator( model_path, device=device, device_index=device_index, - inter_threads=1, - intra_threads=1, - compute_type="default") + inter_threads=inter_threads, + intra_threads=intra_threads, + compute_type=compute_type) self.batch_size = batch_size self.beam_size = beam_size self.n_best = n_best @@ -155,7 +156,10 @@ def start(self, config_file): 'custom_opt': conf.get('custom_opt', None), 'on_timeout': conf.get('on_timeout', None), 'model_root': conf.get('model_root', self.models_root), - 'ct2_model': conf.get('ct2_model', None) + 'ct2_model': conf.get('ct2_model', None), + 'inter_threads': conf.get('inter_threads', None), + 'intra_threads': conf.get('intra_threads', None), + 'compute_type': conf.get('compute_type', None) } kwargs = {k: v for (k, v) in kwargs.items() if v is not None} model_id = conf.get("id", None) @@ -258,11 +262,16 @@ class ServerModel(object): timeout (see :func:`do_timeout()`.) model_root (str): Path to the model directory it must contain the model and tokenizer file + ct2_model (str): Path to the CTranslate2 model directory + inter_threads (int): Maximum number of CPU translations executed in parallel + intra_threads (int): Number of OpenMP threads that is used per translation + compute_type (str): The type used for computation """ def __init__(self, opt, model_id, preprocess_opt=None, tokenizer_opt=None, postprocess_opt=None, custom_opt=None, load=False, timeout=-1, - on_timeout="to_cpu", model_root="./", ct2_model=None): + on_timeout="to_cpu", model_root="./", ct2_model=None, + inter_threads=1, intra_threads=1, compute_type="default"): self.model_root = model_root self.opt = self.parse_opt(opt) self.custom_opt = custom_opt @@ -276,6 +285,9 @@ def __init__(self, opt, model_id, preprocess_opt=None, tokenizer_opt=None, self.ct2_model = os.path.join(model_root, ct2_model) \ if ct2_model is not None else None + self.inter_threads = inter_threads + self.intra_threads = intra_threads + self.compute_type = compute_type self.unload_timer = None self.user_opt = opt @@ -393,7 +405,10 @@ def load(self, preload=False): beam_size=self.opt.beam_size, n_best=self.opt.n_best, target_prefix=self.opt.tgt_prefix, - preload=preload) + preload=preload, + inter_threads=self.inter_threads, + intra_threads=self.intra_threads, + compute_type=self.compute_type) else: self.translator = build_translator( self.opt, report_score=False,