From c5c84af7294e6ad1595db4203721ea476d4faef1 Mon Sep 17 00:00:00 2001 From: Vincent Nguyen Date: Sat, 11 Nov 2023 17:54:24 +0100 Subject: [PATCH] Fixkeybuffer (#2512) * fix buffers loading (awq) --- onmt/models/model.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/onmt/models/model.py b/onmt/models/model.py index 07fe9edc1a..e908a6ff16 100644 --- a/onmt/models/model.py +++ b/onmt/models/model.py @@ -46,7 +46,6 @@ def count_parameters(self, log=print): raise NotImplementedError def _load_param(self, name, module, param_name, param, buf_list, ckpt_t, offset): - if name.split(".")[-1] in [ "linear_keys", "linear_values", @@ -73,7 +72,7 @@ def _load_param(self, name, module, param_name, param, buf_list, ckpt_t, offset) row_slice_start:row_slice_end, ].size() ), "An error in model's partition and checkpoint's slice was detected" - if param_name in buf_list: + if name + "." + param_name in buf_list: module.register_buffer( param_name, ckpt_t[ @@ -90,7 +89,7 @@ def _load_param(self, name, module, param_name, param, buf_list, ckpt_t, offset) assert ( param.data.size() == ckpt_t[col_slice_start:col_slice_end].size() ), "An error in model's partition and checkpoint's slice was detected" - if param_name in buf_list: + if name + "." + param_name in buf_list: module.register_buffer( param_name, ckpt_t[col_slice_start:col_slice_end] ) @@ -120,9 +119,9 @@ def load_state_dict( if device == torch.device("cpu"): offset = 0 buf_list = [] + for buf_name, buf in self.named_buffers(): + buf_list.append(buf_name) for name, module in self.named_modules(): - for buf_name, buf in module.named_buffers(): - buf_list.append(buf_name) named_buf_and_param = list(module.named_buffers()) + list( module.named_parameters() ) @@ -205,9 +204,9 @@ def load_safe_state_dict( if device == torch.device("cpu"): offset = 0 buf_list = [] + for buf_name, buf in self.named_buffers(): + buf_list.append(buf_name) for name, module in self.named_modules(): - for buf_name, buf in module.named_buffers(): - buf_list.append(buf_name) named_buf_and_param = list(module.named_buffers()) + list( module.named_parameters() )