diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index e1800f29b0af..47f0ba0afc0e 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -558,6 +558,10 @@ def load(module: nn.Module, state_dict, prefix="", load_sub_module: bool = True) missing_keys = missing_keys.append(sub_missing_keys) if strict: + if len(error_msgs) > 0: + raise RuntimeError( + "Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs)) + ) if len(unexpected_keys) > 0: error_msgs = "Unexpected key(s) in state_dict: {}. ".format( ", ".join('"{}"'.format(k) for k in unexpected_keys)